init
This commit is contained in:

0    vllm_vacc/vllm/__init__.py                                   Normal file
BIN  vllm_vacc/vllm/__pycache__/__init__.cpython-312.pyc          Normal file (binary file not shown)
BIN  vllm_vacc/vllm/__pycache__/_custom_ops.cpython-312.pyc       Normal file (binary file not shown)
BIN  vllm_vacc/vllm/__pycache__/config.cpython-312.pyc            Normal file (binary file not shown)
BIN  vllm_vacc/vllm/__pycache__/config_manager.cpython-312.pyc    Normal file (binary file not shown)
BIN  vllm_vacc/vllm/__pycache__/sequence.cpython-312.pyc          Normal file (binary file not shown)
95   vllm_vacc/vllm/_custom_ops.py                                 Normal file
@@ -0,0 +1,95 @@
import contextlib
import importlib
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch
import torch.library

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType


def cutlass_scaled_mm_vacc(a: torch.Tensor,
                           b: torch.Tensor,
                           scale_a: torch.Tensor,
                           scale_b: torch.Tensor,
                           out_dtype: torch.dtype,
                           bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    `cutlass_scaled_mm` implements a fused version of
    `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
    where scale_a * a and scale_b * b are implemented using numpy-style
    broadcasting.

    In order to support blockwise scaling like that found in DeepSeek V3 we
    also support extended "group" broadcast rules. We extend the numpy-style
    broadcasting rules with the following rule:
        "if the extent of a dimension in the source shape is between 1 and the
        corresponding extent in the target shape, we repeat each element along
        that dimension target_shape[dim] // src_shape[dim] times consecutively"
    For example, if we have:
        a = [[1, 2],  and target_shape = (2, 4)
             [3, 4]]
    then we would expand a to:
        a = [[1, 1, 2, 2],
             [3, 3, 4, 4]]
    Currently we only support the case:
        scale_a.shape * [1, 128] == a.shape
        scale_b.shape * [128, 128] == b.shape
    """
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or bias.shape[0] == b.shape[
        1] and bias.dtype == out_dtype

    m = a.shape[0]
    n = b.shape[1]

    if current_platform.is_rocm():
        triton_scaled_mm_module = importlib.import_module(
            "vllm.model_executor.layers.quantization.compressed_tensors."
            "triton_scaled_mm")
        triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
        return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
    # print('a', a.shape, a.dtype)      # torch.Size([8192, 3584]) torch.float8_e4m3fn
    # print('scale_a', scale_a.shape)   # torch.Size([8192, 56])
    # print('b', b.shape, b.dtype)      # torch.Size([3584, 1536]) torch.float8_e4m3fn
    # print('scale_b', scale_b.shape)   # torch.Size([56, 12])

    use_a32_w32 = True  # dequantize to fp32 and run the matmul in fp32

    if use_a32_w32 or (b.shape[1] // scale_b.shape[1] != 128 or
                       a.shape[1] // scale_a.shape[1] != 128 or
                       b.shape[0] // scale_b.shape[0] != 128):
        # cutlass_scaled_mm does not support quant blocks other than 128
        a1 = a.to(torch.float32).reshape(a.shape[0], scale_a.shape[1], -1)
        scale_a = scale_a.reshape(scale_a.shape[0], scale_a.shape[1], 1).to(torch.float32)
        a = (a1 * scale_a).reshape(a.shape).contiguous()

        b1 = b.to(torch.float32).reshape(scale_b.shape[0], b.shape[0] // scale_b.shape[0],
                                         scale_b.shape[1], b.shape[1] // scale_b.shape[1])
        scale_b = scale_b.reshape(scale_b.shape[0], 1, scale_b.shape[1], 1).to(torch.float32)
        b = (b1 * scale_b).reshape(b.shape).contiguous()

        out = a @ b
        if bias is not None:
            out = out + bias
        return out.to(out_dtype)

    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)

    return out


def concat_and_cache_mla(
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    scale: torch.Tensor,
) -> None:
    torch.vacc.concat_and_cache_attention(
        kv_c, k_pe, kv_cache, slot_mapping)
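
Note: the group-broadcast rule described in the `cutlass_scaled_mm_vacc` docstring can be sanity-checked with a small standalone sketch. The shapes and the 128-wide scale group below are illustrative assumptions, not values taken from the kernel.

    import torch

    block = 128
    a_q = torch.randn(4, 2 * block)   # stand-in for a quantized activation tensor
    scale_a = torch.rand(4, 2)        # one scale per token per 128-column group

    # Path 1: expand the scales with the "group" broadcast rule, then multiply.
    dq_broadcast = a_q * scale_a.repeat_interleave(block, dim=1)

    # Path 2: the reshape-and-multiply form used inside cutlass_scaled_mm_vacc.
    dq_reshape = (a_q.reshape(4, 2, block) * scale_a.unsqueeze(-1)).reshape(4, -1)

    assert torch.allclose(dq_broadcast, dq_reshape)
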
0    vllm_vacc/vllm/attention/__init__.py                              Normal file
BIN  vllm_vacc/vllm/attention/__pycache__/__init__.cpython-312.pyc     Normal file (binary file not shown)
0    vllm_vacc/vllm/attention/backends/__init__.py                     Normal file
BIN  (binary files not shown)
0    vllm_vacc/vllm/attention/backends/mla/__init__.py                 Normal file
BIN  (binary files not shown)
202  vllm_vacc/vllm/attention/backends/mla/common.py                   Normal file
@@ -0,0 +1,202 @@
|
||||
|
||||
import functools
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
|
||||
Type, TypeVar)
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                              AttentionMetadata,
                                              AttentionMetadataBuilder,
                                              AttentionState, MLAAttentionImpl)
from vllm.attention.backends.mla.common import MLACommonMetadata, triton_attention
|
||||
|
||||
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
try:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
is_vllm_fa = True
|
||||
except ImportError:
|
||||
is_vllm_fa = False
|
||||
try:
|
||||
# For rocm use upstream flash attention
|
||||
from vllm.attention.backends.flash_attn import flash_attn_varlen_func
|
||||
except ImportError:
|
||||
flash_attn_varlen_func = None
|
||||
|
||||
T = TypeVar("T", bound="MLACommonMetadata")
|
||||
|
||||
|
||||
class MLACommonImpl():
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
# blocksparse_params: Optional[Dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
q_lora_rank: Optional[int],
|
||||
kv_lora_rank: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
rotary_emb: RotaryEmbedding,
|
||||
# q_proj should be q_b_proj if q_lora_rank is not None, but from an
|
||||
# attention backend perspective we rely on the layer to pass in the
|
||||
# correct matrix
|
||||
q_proj: ColumnParallelLinear,
|
||||
kv_b_proj: ColumnParallelLinear,
|
||||
o_proj: RowParallelLinear,
|
||||
positions: torch.Tensor = None,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.qk_head_dim = qk_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
|
||||
self.rotary_emb = rotary_emb
|
||||
self.use_yarn_rope = isinstance(rotary_emb,
|
||||
DeepseekScalingRotaryEmbedding)
|
||||
self.q_proj = q_proj
|
||||
self.kv_b_proj = kv_b_proj
|
||||
self.o_proj = o_proj
|
||||
self.positions = positions
|
||||
|
||||
self.triton_fa_func = triton_attention
|
||||
# Handle the differences between the flash_attn_varlen from flash_attn
|
||||
# and the one from vllm_flash_attn. The former is used on RoCM and the
|
||||
# latter has an additional parameter to control FA2 vs FA3
|
||||
self.flash_attn_varlen_func = flash_attn_varlen_func
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
if self.vllm_flash_attn_version is not None:
|
||||
self.flash_attn_varlen_func = \
|
||||
functools.partial(flash_attn_varlen_func,
|
||||
fa_version=self.vllm_flash_attn_version)
|
||||
|
||||
# For MLA the v head dim is smaller than qk head dim so we pad out
|
||||
# v with 0s to match the qk head dim for attention backends that do
|
||||
# not support different headdims
|
||||
# We don't need to pad V if we are on a hopper system with FA3
|
||||
self._pad_v = self.vllm_flash_attn_version is None or not (
|
||||
self.vllm_flash_attn_version == 3
|
||||
and current_platform.get_device_capability()[0] == 9)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
hidden_states_or_q_c: torch.Tensor, # query in unified attn
|
||||
k_c_normed: torch.Tensor, # key in unified attn
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if output is not None:
|
||||
raise NotImplementedError(
|
||||
"output is not yet supported for MLAImplBase")
|
||||
|
||||
# if attn_metadata.is_profile_run and \
|
||||
# attn_metadata.context_chunk_workspace is not None:
|
||||
# # During the profile run try to simulate the worst-case output size
|
||||
# # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
|
||||
# # since this can be large
|
||||
# _ = torch.empty(
|
||||
# (attn_metadata.context_chunk_workspace.shape[0],
|
||||
# self.num_heads, self.qk_nope_head_dim + self.v_head_dim),
|
||||
# device=k_c_normed.device,
|
||||
# dtype=k_c_normed.dtype,
|
||||
# )
|
||||
|
||||
has_decode = attn_metadata.decode_metadata is not None
|
||||
has_prefill = attn_metadata.prefill_metadata is not None
|
||||
|
||||
# Restore head dim (for rotary embedding)
|
||||
k_pe = k_pe.unsqueeze(1)
|
||||
# assert hasattr(attn_metadata, "input_positions")
|
||||
if self.positions is not None:
|
||||
positions = self.positions
|
||||
elif hasattr(attn_metadata, "input_positions"):
|
||||
positions = attn_metadata.input_positions
|
||||
else:
|
||||
raise ValueError('no positions')
|
||||
|
||||
|
||||
num_prefill_tokens: int = attn_metadata.num_prefill_tokens
|
||||
|
||||
decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:]
|
||||
decode_k_pe = k_pe[num_prefill_tokens:]
|
||||
decode_input_positions = \
|
||||
positions[num_prefill_tokens:]
|
||||
|
||||
prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens]
|
||||
prefill_k_pe = k_pe[:num_prefill_tokens]
|
||||
prefill_input_positions = \
|
||||
positions[:num_prefill_tokens]
|
||||
prefill_k_c_normed = k_c_normed[:num_prefill_tokens]
|
||||
|
||||
if has_decode:
|
||||
decode_ql_nope, decode_q_pe = \
|
||||
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
||||
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
||||
decode_input_positions, decode_q_pe, decode_k_pe)
|
||||
|
||||
if has_prefill:
|
||||
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
||||
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
||||
prefill_input_positions, prefill_q_pe, prefill_k_pe)
|
||||
|
||||
# write the latent and rope to kv cache
|
||||
if kv_cache.numel() > 0:
|
||||
ops.concat_and_cache_mla(
|
||||
k_c_normed,
|
||||
k_pe.squeeze(1),
|
||||
kv_cache,
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
scale=layer._k_scale,
|
||||
)
|
||||
|
||||
# output = torch.empty(attn_metadata.num_prefill_tokens +
|
||||
# attn_metadata.num_decode_tokens,
|
||||
# self.o_proj.output_size,
|
||||
# device=hidden_states_or_q_c.device,
|
||||
# dtype=hidden_states_or_q_c.dtype)
|
||||
if has_prefill:
|
||||
return self._forward_prefill(
|
||||
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
|
||||
attn_metadata)
|
||||
|
||||
if has_decode:
|
||||
return self._forward_decode(
|
||||
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
|
||||
|
||||
assert False, "mla forward need prefill or decode function"
|
||||
return None
|
||||
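
Note: `MLACommonImpl.forward` above relies on vLLM packing all prefill tokens ahead of all decode tokens in one flat batch, so a single `num_prefill_tokens` split separates the two paths. A minimal sketch of that slicing convention, with made-up token counts:

    import torch

    # Assumed layout: 5 prefill tokens followed by 2 decode tokens.
    num_prefill_tokens = 5
    hidden = torch.randn(7, 4)                       # [num_tokens, hidden_dim]
    positions = torch.tensor([0, 1, 2, 0, 1, 7, 3])  # per-token positions

    prefill_hidden = hidden[:num_prefill_tokens]     # goes through _forward_prefill
    decode_hidden = hidden[num_prefill_tokens:]      # goes through _forward_decode
    prefill_pos = positions[:num_prefill_tokens]
    decode_pos = positions[num_prefill_tokens:]

    assert prefill_hidden.shape[0] + decode_hidden.shape[0] == hidden.shape[0]
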
390  vllm_vacc/vllm/attention/backends/mla/utils.py    Normal file
@@ -0,0 +1,390 @@
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Generic, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl, T)
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsW8A8Fp8)
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
||||
# from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
# apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
scaled_dequantize, scaled_quantize)
|
||||
import os
|
||||
|
||||
W_Q_W_QR_WUV_WUK_USE_FP8 = True
|
||||
|
||||
class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
|
||||
def is_layer_fp8(layer: LinearBase) -> bool:
|
||||
return isinstance(layer.quant_method, Fp8LinearMethod) or\
|
||||
(isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
|
||||
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
|
||||
|
||||
def quantization_scheme_supported(layer: LinearBase) -> bool:
|
||||
return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
|
||||
is_layer_fp8(layer)
|
||||
|
||||
# TODO(lucas) This is very gross, we need a more wide scale refactor of
|
||||
# all the FP8 code with a more standard way of
|
||||
# defining schemes/group-shapes, we should also potentially force
|
||||
# quant_methods to support a decompress function
|
||||
#
|
||||
# returns input_group_shape, weight_group_shape
|
||||
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
|
||||
Tuple[Tuple[int, int], Tuple[int, int]]:
|
||||
if isinstance(layer.quant_method, Fp8LinearMethod):
|
||||
if layer.quant_method.block_quant is not None:
|
||||
weight_block_size = \
|
||||
layer.quant_method.quant_config.weight_block_size
|
||||
# per-token-group (1, X), block-quantized (X, Y)
|
||||
return (1, weight_block_size[-1]), weight_block_size
|
||||
else:
|
||||
return (-1, -1), (-1, -1) # per-tensor, per-tensor
|
||||
elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
|
||||
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# this is hacky but we always assume the for
|
||||
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
|
||||
# we ignore if it is static-per-tensor since we are going to
|
||||
# requantize after later anyways
|
||||
strategy = layer.scheme.strategy
|
||||
if strategy == QuantizationStrategy.TENSOR:
|
||||
return (1, -1), (-1, -1) # per-token, per-tensor
|
||||
elif strategy == QuantizationStrategy.CHANNEL:
|
||||
return (1, -1), (-1, 1) # per-token, per-channel
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"QuantizationStrategy.{strategy} is not supported for "
|
||||
"fp8 MLA, please run with VLLM_MLA_DISABLE=1")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Can't determine scale group shapes for "
|
||||
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
|
||||
)
|
||||
|
||||
def get_scales(layer: LinearBase) -> torch.Tensor:
|
||||
if hasattr(layer, "weight_scale_inv"):
|
||||
return layer.weight_scale_inv
|
||||
return layer.weight_scale
|
||||
|
||||
def get_fp8_layer_weight(layer: LinearBase):
|
||||
if is_layer_fp8(layer):
|
||||
if isinstance(layer.quant_method, \
|
||||
CompressedTensorsLinearMethod) and \
|
||||
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# NOTE(lucas): not sure why, but `CompressedTensorsW8A8Fp8`
|
||||
# seems to store weights as (input, output) instead of
|
||||
# (output, input) so we need to transpose
|
||||
weight = layer.weight.T # standardize to (output, input)
|
||||
else:
|
||||
weight = layer.weight
|
||||
_, weight_scale_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(layer)
|
||||
scales = get_scales(layer)  # already expanded
weight_scale_group_shape = weight_scale_group_shape.copy()  # the [128, 128] read from the config must be .copy()'d, otherwise the config itself would be modified

# Re-derive weight_scale_group_shape from the actual weight/scale shapes
if weight.shape[0] // scales.shape[0] != weight_scale_group_shape[0]:
    weight_scale_group_shape[0] = weight.shape[0] // scales.shape[0]

if weight.shape[1] // scales.shape[1] != weight_scale_group_shape[1]:
    weight_scale_group_shape[1] = weight.shape[1] // scales.shape[1]
|
||||
|
||||
return weight, scales
|
||||
else:
|
||||
return layer.weight, None
|
||||
|
||||
def get_fp8_layer_weight_test(layer: LinearBase):
|
||||
if is_layer_fp8(layer):
|
||||
if isinstance(layer.quant_method, \
|
||||
CompressedTensorsLinearMethod) and \
|
||||
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# NOTE(lucas): not sure why, but `CompressedTensorsW8A8Fp8`
|
||||
# seems to store weights as (input, output) instead of
|
||||
# (output, input) so we need to transpose
|
||||
weight = layer.weight.T # standardize to (output, input)
|
||||
else:
|
||||
weight = layer.weight
|
||||
_, weight_scale_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(layer)
|
||||
scales = get_scales(layer)  # already expanded
weight_scale_group_shape = weight_scale_group_shape.copy()  # the [128, 128] read from the config must be .copy()'d, otherwise the config itself would be modified

# Re-derive weight_scale_group_shape from the actual weight/scale shapes
if weight.shape[0] // scales.shape[0] != weight_scale_group_shape[0]:
    weight_scale_group_shape[0] = weight.shape[0] // scales.shape[0]

if weight.shape[1] // scales.shape[1] != weight_scale_group_shape[1]:
    weight_scale_group_shape[1] = weight.shape[1] // scales.shape[1]
|
||||
|
||||
# for test
|
||||
weight = scaled_dequantize(weight, scales, weight_scale_group_shape)
|
||||
# print(f'{weight.shape}, {scales.shape}, {weight_scale_group_shape}')
|
||||
return weight, scales
|
||||
else:
|
||||
return layer.weight, None
|
||||
|
||||
def check_eq(name, tensor0, tensor1):
|
||||
assert tensor0.shape == tensor1.shape
|
||||
isEqual = torch.equal(tensor0.reshape([-1]).float(), tensor1.reshape([-1]).float())
|
||||
print(f"{os.getpid()} check {name} {tensor0.shape} equal: {isEqual}")
|
||||
return isEqual
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: LinearBase):
|
||||
if is_layer_fp8(layer):
|
||||
if isinstance(layer.quant_method, \
|
||||
CompressedTensorsLinearMethod) and \
|
||||
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# NOTE(lucas): not sure why, but `CompressedTensorsW8A8Fp8`
|
||||
# seems to store weights as (input, output) instead of
|
||||
# (output, input) so we need to transpose
|
||||
weight = layer.weight.T # standardize to (output, input)
|
||||
else:
|
||||
weight = layer.weight
|
||||
_, weight_scale_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(layer)
|
||||
scales = get_scales(layer)  # already expanded
weight_scale_group_shape = weight_scale_group_shape.copy()  # the [128, 128] read from the config must be .copy()'d, otherwise the config itself would be modified

# Re-derive weight_scale_group_shape from the actual weight/scale shapes
if weight.shape[0] // scales.shape[0] != weight_scale_group_shape[0]:
    weight_scale_group_shape[0] = weight.shape[0] // scales.shape[0]

if weight.shape[1] // scales.shape[1] != weight_scale_group_shape[1]:
    weight_scale_group_shape[1] = weight.shape[1] // scales.shape[1]
|
||||
|
||||
return scaled_dequantize(weight, scales,
|
||||
weight_scale_group_shape)
|
||||
else:
|
||||
return layer.weight
|
||||
|
||||
if not (quantization_scheme_supported(self.kv_b_proj) and\
|
||||
quantization_scheme_supported(self.q_proj) and\
|
||||
quantization_scheme_supported(self.o_proj)):
|
||||
raise NotImplementedError(
|
||||
"Only FP8 and UnquantizedLinearMethod are supported for MLA"
|
||||
", please run with VLLM_MLA_DISABLE=1")
|
||||
|
||||
weight_dtype = self.kv_b_proj.weight.dtype
|
||||
assert self.o_proj.weight.dtype == weight_dtype
|
||||
assert self.q_proj.weight.dtype == weight_dtype
|
||||
|
||||
if W_Q_W_QR_WUV_WUK_USE_FP8: #and not envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
# 512,1024(=4x256)
|
||||
kv_b_proj_weight, kv_b_proj_scale = \
|
||||
[t.T for t in get_fp8_layer_weight(self.kv_b_proj)]
|
||||
|
||||
# kv_b_proj_weight = kv_b_proj_weight.transpose(-1,-2).contiguous().transpose(-1,-2)
|
||||
N, K = kv_b_proj_weight.shape[0], kv_b_proj_weight.shape[1]
|
||||
|
||||
# 512,1024 => 512,4,256
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
kv_b_proj_scale = kv_b_proj_scale.view(
|
||||
kv_b_proj_scale.shape[0] * self.kv_lora_rank // N,
|
||||
self.num_heads,
|
||||
kv_b_proj_scale.shape[1] * N // (self.kv_lora_rank * self.num_heads),
|
||||
)
|
||||
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
W_UK = W_UK.contiguous()
|
||||
|
||||
scale_0 = kv_b_proj_scale.shape[-1] * self.qk_nope_head_dim // (self.qk_nope_head_dim + self.v_head_dim)
|
||||
scale_1 = kv_b_proj_scale.shape[-1] - scale_0
|
||||
|
||||
W_UK_scale, W_UV_scale = kv_b_proj_scale.split(
|
||||
[scale_0, scale_1], dim=-1)
|
||||
W_UK_scale = W_UK_scale.view(W_UK_scale.shape[0], -1).unsqueeze(-1).contiguous()
|
||||
W_UV_scale = W_UV_scale.view(W_UV_scale.shape[0], -1).unsqueeze(-1)
|
||||
|
||||
# weight: [1536, 768] scale: 12,6
|
||||
q_proj_weight, q_proj_scale = \
|
||||
[t.T for t in get_fp8_layer_weight(self.q_proj)]
|
||||
|
||||
#self.W_Q_QR = q_proj_weight.contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
#self.W_Q_QR_scales = q_proj_scale.reshape(12, 6, 1).repeat(1, 1, 4).reshape(12, -1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
|
||||
q_proj_weight = q_proj_weight\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
# w_q[1536, 512] + w_qr[1536, 256]
|
||||
W_Q = q_proj_weight[..., :self.qk_nope_head_dim].flatten(start_dim=1)
|
||||
W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
# w_q_scale 12,16 + w_qr_scale 12,8
|
||||
# expand: 12,6(4+2) -> 12,24(16+8)
|
||||
# Q_scale: [s0x4, s1x2, s2x2, s3x4, s4x2, s5x2]
|
||||
repeat_pattern = torch.tensor([4, 2, 2, 4, 2, 2], device=q_proj_scale.device)
|
||||
W_Q_scale = torch.repeat_interleave(q_proj_scale, repeat_pattern, dim=1)
|
||||
# Q_R_scale: [s1x2, s2x2, s4x2, s5x2]
|
||||
selected_indices = [1, 2, 4, 5]
|
||||
repeat_times = 2
|
||||
selected = q_proj_scale[:, selected_indices]
|
||||
W_QR_scale = selected.repeat_interleave(repeat_times, dim=1)
|
||||
|
||||
# temp_WQ_Scale = W_Q_scale.reshape(12, 4, -1).contiguous()
|
||||
# temp_W_QR_scale = W_QR_scale.reshape(12, 4, -1).contiguous()
|
||||
# temp_scale = torch.cat([temp_WQ_Scale, temp_W_QR_scale], dim=2).contiguous().reshape(12, -1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
# self.W_Q_QR_scales = temp_scale
|
||||
# print("W_Q_scale:", W_Q_scale.shape)
|
||||
# print("W_QR_scale:", W_QR_scale.shape)
|
||||
# print("temp_scale:", temp_scale.shape)
|
||||
# exit(0)
|
||||
|
||||
# Note: to be vnnl compatible
|
||||
# 1. expand w_uv scale for core split friendly
|
||||
if W_UV.shape[-1] % 4 == 0:
|
||||
W_UV_scale = W_UV_scale.expand((W_UV_scale.shape[0], W_UV_scale.shape[1], 4))
|
||||
# 2. change w_q, w_qr, w_uv weight&scale to K-contiguous (shape unchanged)
|
||||
W_Q = W_Q.transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
W_Q_scale = W_Q_scale.transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
W_QR = W_QR.transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
W_QR_scale = W_QR_scale.transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
|
||||
W_UV = W_UV.permute(2,1,0).contiguous().permute(2,1,0)
|
||||
W_UV_scale = W_UV_scale.permute(2,1,0).contiguous().permute(2,1,0)
|
||||
|
||||
self.W_Q = W_Q
|
||||
self.W_Q_scales = W_Q_scale
|
||||
|
||||
self.W_QR = W_QR
|
||||
self.W_QR_scales = W_QR_scale
|
||||
|
||||
# temp_Q_scale = self.W_Q_scales.contiguous()
|
||||
# temp_W_QR_scale = self.W_QR_scales.contiguous()
|
||||
# self.W_Q_QR = q_proj_weight.reshape(1536, -1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
# self.W_Q_QR_scales = torch.concat([temp_Q_scale,temp_W_QR_scale],dim=1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
#self.W_Q_QR = torch.concat([self.W_Q.contiguous(),self.W_QR.contiguous()],dim=1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
#self.W_Q_QR_scales = torch.concat([W_Q_scale,W_QR_scale],dim=1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
|
||||
self.W_UV = W_UV
|
||||
self.W_UV_scales = W_UV_scale
|
||||
|
||||
self.W_UK = W_UK
|
||||
self.W_UK_scales = W_UK_scale
|
||||
return
|
||||
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
||||
f"{kv_b_proj_weight.shape=}, "
|
||||
f"{self.kv_lora_rank=}, "
|
||||
f"{self.num_heads=}, "
|
||||
f"{self.qk_nope_head_dim=}, "
|
||||
f"{self.v_head_dim=}")
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
|
||||
# can be W_Q or W_UQ depending q_lora_rank, the former if
|
||||
# q_lora_rank is None, the latter otherwise. From the Attention backend
|
||||
# perspective though we call these both W_Q and rely on the layer
|
||||
# to pass in the correct matrix
|
||||
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
|
||||
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
|
||||
# W_QR is small so for simplicity we dont bother requantizing it
|
||||
self.W_QR = self.W_QR.to(act_dtype)
|
||||
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
assert False, "please set VLLM_MLA_PERFORM_MATRIX_ABSORPTION=0"
|
||||
requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
# This assumes it wise to requantize using the same group shapes
|
||||
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
|
||||
# weights were originally quantized
|
||||
requant_input_group_shape, requant_weight_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(self.q_proj)
|
||||
assert (requant_input_group_shape, requant_weight_group_shape)\
|
||||
== get_scale_group_shapes_for_fp8(self.kv_b_proj)
|
||||
assert (requant_input_group_shape, requant_weight_group_shape)\
|
||||
== get_scale_group_shapes_for_fp8(self.o_proj)
|
||||
self.reqaunt_input_group_shape = requant_input_group_shape
|
||||
self.reqaunt_weight_group_shape = requant_weight_group_shape
|
||||
|
||||
#
|
||||
# Perform matrix-absorption following
|
||||
# https://github.com/flashinfer-ai/flashinfer/pull/551
|
||||
# for decode, as a result we end up with absorbed weights for decode
|
||||
# and another copy of raw weights for prefill.
|
||||
#
|
||||
self.W_UK, self.W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
# We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
|
||||
# depending q_lora_rank, the former if q_lora_rank is None, the
|
||||
# latter otherwise
|
||||
# basically if q_lora_rank is none we are absorbing into q_proj
|
||||
# instead of UQ
|
||||
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
W_Q_UK, W_Q_UK_scales = scaled_quantize(
|
||||
W_Q_UK,
|
||||
self.reqaunt_weight_group_shape,
|
||||
quant_dtype=current_platform_fp8_dtype)
|
||||
# For FP8 save the transpose so we can use
|
||||
# `apply_w8a8_block_fp8_linear` directly
|
||||
self.W_Q_UK = W_Q_UK.T.contiguous()
|
||||
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
|
||||
else:
|
||||
self.W_Q_UK = W_Q_UK.to(act_dtype)
|
||||
|
||||
W_O = get_and_maybe_dequant_weights(self.o_proj)\
|
||||
.view(-1, self.num_heads, self.v_head_dim)
|
||||
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
|
||||
.flatten(start_dim=0, end_dim=1).contiguous()
|
||||
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
W_UV_O, W_UV_O_scales = scaled_quantize(
|
||||
W_UV_O,
|
||||
self.reqaunt_weight_group_shape,
|
||||
quant_dtype=current_platform_fp8_dtype)
|
||||
# For FP8 save the transpose so we can use
|
||||
# `apply_w8a8_block_fp8_linear` directly
|
||||
self.W_UV_O = W_UV_O.T.contiguous()
|
||||
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
|
||||
else:
|
||||
self.W_UV_O = W_UV_O.to(act_dtype)
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
else:
|
||||
# print('W_UV', W_UV.dtype) #float32
|
||||
#if is_fp8(weight_dtype):
|
||||
# raise NotImplementedError(
|
||||
# "Currently fp8 requires matrix absorption")
|
||||
# self.W_UV = W_UV
|
||||
# self.W_UK = W_UK
|
||||
self.W_UV = W_UV.to(act_dtype)  # fp32 to bf16
|
||||
self.W_UK = W_UK.to(act_dtype)
|
||||
W_Q = W_Q.to(act_dtype)
|
||||
self.W_Q = W_Q.flatten(start_dim=1)
|
||||
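
Note: the matrix absorption that the comments in `process_weights_after_loading` describe (folding `W_UK` into the query projection so decode can attend directly over the cached latent vectors) reduces to a single einsum. The sizes below are arbitrary stand-ins, not DeepSeek's real dimensions:

    import torch

    q_dim, num_heads, nope_dim, kv_lora_rank = 8, 2, 4, 6
    W_Q = torch.randn(q_dim, num_heads, nope_dim)           # per-head query projection
    W_UK = torch.randn(kv_lora_rank, num_heads, nope_dim)   # up-projection of latent keys

    # Absorbed weight: queries produced with W_Q_UK live in the kv_lora_rank space,
    # so they can be matched against cached latents without up-projecting the keys.
    W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK).flatten(start_dim=1)

    assert W_Q_UK.shape == (q_dim, num_heads * kv_lora_rank)
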
726  vllm_vacc/vllm/attention/backends/vacc_attn.py    Normal file
@@ -0,0 +1,726 @@
|
||||
""" Attention layer with torch scaled_dot_product_attention
|
||||
and PagedAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata,
                                              AttentionMetadataBuilder,
                                              AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
|
||||
# from vllm.attention.ops.paged_attn import PagedAttentionMetadata
|
||||
|
||||
from vllm_vacc.vllm.attention.ops.vacc_paged_attn import VaccPagedAttention as PagedAttention
|
||||
|
||||
# from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
# AttentionLayer,
|
||||
# AttentionMetadata,
|
||||
# AttentionMetadataBuilder,
|
||||
# AttentionType)
|
||||
# from vllm.attention.backends.utils import CommonAttentionState
|
||||
# from vllm.attention.ops.ipex_attn import PagedAttention
|
||||
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import make_tensor_with_pad
|
||||
from vllm_vacc.vllm.v1.worker.vacc_model_runner import ModelInputForVACCBuilder
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
import os
|
||||
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
||||
|
||||
class VACCAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TORCH_VACC"
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["VACCAttentionBackendImpl"]:
|
||||
return VACCAttentionBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||
return VACCAttentionMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["VACCMetadataBuilder"]:
|
||||
return VACCMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
|
||||
num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VACCAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
"""Metadata for VACCAttentionMetadata.
|
||||
"""
|
||||
# Currently, input sequences can only contain all prompts
|
||||
# or all decoding. True if all sequences are prompts.
|
||||
chunked_prefill: bool
|
||||
seq_lens: Optional[List[int]] = None # For non-chunked prefill
|
||||
# For chunked prefill only
|
||||
max_query_len: Optional[int] = None
|
||||
max_kv_len: Optional[int] = None
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
kv_start_loc: Optional[torch.Tensor] = None
|
||||
prefill_block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
# Begin encoder attn & enc/dec cross-attn fields...
|
||||
# Encoder sequence lengths representation
|
||||
encoder_seq_lens: Optional[List[int]] = None
|
||||
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# Maximum sequence length among encoder sequences
|
||||
max_encoder_seq_len: Optional[int] = None
|
||||
|
||||
# Number of tokens input to encoder
|
||||
num_encoder_tokens: Optional[int] = None
|
||||
|
||||
# Cross-attention memory-mapping data structures: slot mapping
|
||||
# and block tables
|
||||
cross_slot_mapping: Optional[torch.Tensor] = None
|
||||
cross_block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Set during the execution of the first attention op.
|
||||
# It is a list because it is needed to set per prompt
|
||||
# when alibi slopes is used. It is because of the limitation
|
||||
# from xformer API.
|
||||
# will not appear in the __repr__ and __init__
|
||||
self.attn_bias: Optional[List[torch.Tensor]] = None
|
||||
self.encoder_attn_bias: Optional[List[torch.Tensor]] = None
|
||||
self.cross_attn_bias: Optional[List[torch.Tensor]] = None
|
||||
|
||||
@property
|
||||
def is_all_encoder_attn_metadata_set(self):
|
||||
'''
|
||||
All attention metadata required for encoder attention is set.
|
||||
'''
|
||||
return ((self.encoder_seq_lens is not None)
|
||||
and (self.encoder_seq_lens_tensor is not None)
|
||||
and (self.max_encoder_seq_len is not None))
|
||||
|
||||
@property
|
||||
def is_all_cross_attn_metadata_set(self):
|
||||
'''
|
||||
All attention metadata required for enc/dec cross-attention is set.
|
||||
|
||||
Superset of encoder attention required metadata.
|
||||
'''
|
||||
return (self.is_all_encoder_attn_metadata_set
|
||||
and (self.cross_slot_mapping is not None)
|
||||
and (self.cross_block_tables is not None))
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["VACCAttentionMetadata"]:
|
||||
# Currently chunked prefill is not supported
|
||||
if self.num_prefill_tokens == 0:
|
||||
return None
|
||||
return self
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["VACCAttentionMetadata"]:
|
||||
# Currently chunked prefill is not supported
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
return self
|
||||
|
||||
def get_seq_lens(
|
||||
self,
|
||||
attn_type: AttentionType,
|
||||
):
|
||||
'''
|
||||
Extract appropriate sequence lengths from attention metadata
|
||||
according to attention type.
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
|
||||
Returns:
|
||||
* Appropriate sequence lengths tensor for query
|
||||
* Appropriate sequence lengths tensor for key & value
|
||||
'''
|
||||
|
||||
if (attn_type == AttentionType.DECODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY):
|
||||
seq_lens_q = self.seq_lens
|
||||
seq_lens_kv = self.seq_lens
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
seq_lens_q = self.encoder_seq_lens
|
||||
seq_lens_kv = self.encoder_seq_lens
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
seq_lens_q = self.seq_lens
|
||||
seq_lens_kv = self.encoder_seq_lens
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
return seq_lens_q, seq_lens_kv
|
||||
|
||||
def get_attn_bias(
|
||||
self,
|
||||
attn_type: AttentionType,
|
||||
) -> Optional[List[torch.Tensor]]:
|
||||
'''
|
||||
Extract appropriate attention bias from attention metadata
|
||||
according to attention type.
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
|
||||
Returns:
|
||||
* Appropriate attention bias value given the attention type
|
||||
'''
|
||||
|
||||
if (attn_type == AttentionType.DECODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY):
|
||||
return self.attn_bias
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
return self.encoder_attn_bias
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
return self.cross_attn_bias
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
def set_attn_bias(
|
||||
self,
|
||||
attn_bias: List[torch.Tensor],
|
||||
attn_type: AttentionType,
|
||||
) -> None:
|
||||
'''
|
||||
Update appropriate attention bias field of attention metadata,
|
||||
according to attention type.
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention
|
||||
* attn_bias: The desired attention bias value
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
'''
|
||||
|
||||
if (attn_type == AttentionType.DECODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY):
|
||||
self.attn_bias = attn_bias
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
self.encoder_attn_bias = attn_bias
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
self.cross_attn_bias = attn_bias
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
def get_seq_len_block_table_args(
|
||||
self,
|
||||
attn_type: str,
|
||||
) -> tuple:
|
||||
'''
|
||||
The particular choice of sequence-length- and block-table-related
|
||||
attributes which should be extracted from attn_metadata is dependent
|
||||
on the type of attention operation.
|
||||
|
||||
Decoder attn -> select entirely decoder self-attention-related fields
|
||||
Encoder/decoder cross-attn -> select encoder sequence lengths &
|
||||
cross-attn block-tables fields
|
||||
Encoder attn -> select encoder sequence lengths fields & no block tables
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention
|
||||
* is_prompt: True if prefill, False otherwise
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
|
||||
Returns:
|
||||
|
||||
* Appropriate sequence-lengths tensor
|
||||
* Appropriate max sequence-length scalar
|
||||
* Appropriate block tables (or None)
|
||||
'''
|
||||
|
||||
if (attn_type == AttentionType.DECODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY):
|
||||
# Decoder self-attention
|
||||
# Choose max_seq_len based on whether we are in prompt_run
|
||||
return (self.seq_lens_tensor, self.max_decode_seq_len,
|
||||
self.block_tables)
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Enc/dec cross-attention KVs match encoder sequence length;
|
||||
# cross-attention utilizes special "cross" block tables
|
||||
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
|
||||
self.cross_block_tables)
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
# No block tables associated with encoder attention
|
||||
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
|
||||
None)
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
class VACCMetadataBuilder(AttentionMetadataBuilder[VACCAttentionMetadata]):
|
||||
|
||||
def __init__(self, input_builder: ModelInputForVACCBuilder) -> None:
|
||||
self.chunked_prefill = input_builder.chunked_prefill
|
||||
self.input_builder = input_builder
|
||||
|
||||
def prepare(self):
|
||||
self.input_data = self.input_builder.input_data
|
||||
|
||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||
cuda_graph_pad_size: int, batch_size: int) -> VACCAttentionMetadata:
|
||||
input_data = self.input_data
|
||||
prefill_seq_lens = seq_lens[0:input_data.num_prefills]
|
||||
prefill_query_lens = query_lens[0:input_data.num_prefills]
|
||||
slot_mapping = torch.tensor(input_data.slot_mapping,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device)
|
||||
|
||||
# For chunked-prefill
|
||||
if self.chunked_prefill and input_data.num_prefill_tokens != 0:
|
||||
prefill_block_tables = make_tensor_with_pad(
|
||||
self.input_data.prefill_block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device,
|
||||
)
|
||||
query_lens_tensor = torch.tensor(prefill_query_lens,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device)
|
||||
kv_lens_tensor = torch.tensor(prefill_seq_lens,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device)
|
||||
query_start_loc = torch.zeros(input_data.num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device)
|
||||
kv_start_loc = torch.zeros(input_data.num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device)
|
||||
torch.cumsum(query_lens_tensor,
|
||||
dim=0,
|
||||
dtype=torch.int32,
|
||||
out=query_start_loc[1:])
|
||||
torch.cumsum(kv_lens_tensor,
|
||||
dim=0,
|
||||
dtype=torch.int32,
|
||||
out=kv_start_loc[1:])
|
||||
max_query_len = max(prefill_query_lens)
|
||||
max_kv_len = max(prefill_seq_lens)
|
||||
else:
|
||||
prefill_block_tables = None
|
||||
query_start_loc = None
|
||||
kv_start_loc = None
|
||||
max_query_len = None
|
||||
max_kv_len = None
|
||||
|
||||
# For paged attention
|
||||
if input_data.num_decode_tokens != 0:
|
||||
seq_lens_tensor = torch.tensor(
|
||||
input_data.seq_lens[input_data.num_prefills:],
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device,
|
||||
)
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.input_data.decode_block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device,
|
||||
)
|
||||
# lowest_dim_size = block_tables.size(-1)
|
||||
# if lowest_dim_size < 1024:
|
||||
# padding_amount = 1024 - lowest_dim_size
|
||||
# padding = torch.zeros(*block_tables.size()[:-1], padding_amount, dtype=block_tables.dtype, device=block_tables.device)
|
||||
# block_tables = torch.cat((block_tables, padding), dim=-1)
|
||||
else:
|
||||
block_tables = torch.tensor([])
|
||||
seq_lens_tensor = torch.tensor(
|
||||
input_data.seq_lens[:input_data.num_prefills],
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device,
|
||||
)
|
||||
|
||||
# For multi-modal models
|
||||
placeholder_index_maps = None
|
||||
if len(input_data.multi_modal_inputs_list) != 0:
|
||||
placeholder_index_maps = {
|
||||
modality: placeholder_map.index_map()
|
||||
for modality, placeholder_map in
|
||||
input_data.multi_modal_placeholder_maps.items()
|
||||
}
|
||||
|
||||
attn_metadata = VACCAttentionMetadata(
|
||||
chunked_prefill=self.chunked_prefill,
|
||||
seq_lens=seq_lens, #prefill_seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=max_query_len,
|
||||
max_kv_len=max_kv_len,
|
||||
query_start_loc=query_start_loc,
|
||||
kv_start_loc=kv_start_loc,
|
||||
max_decode_seq_len=None,
|
||||
num_prefills=input_data.num_prefills,
|
||||
num_prefill_tokens=input_data.num_prefill_tokens,
|
||||
num_decode_tokens=input_data.num_decode_tokens,
|
||||
block_tables=block_tables,
|
||||
prefill_block_tables=prefill_block_tables,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=False,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
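
Note: a toy illustration of the two main artifacts `VACCMetadataBuilder.build` produces for chunked prefill, cumulative start offsets and zero-padded block tables. The per-sequence lengths and block ids are invented for illustration:

    import torch

    query_lens = [3, 1, 4]
    block_tables = [[7, 2], [5], [9, 4, 1]]

    # Cumulative start offsets, analogous to query_start_loc above.
    query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
    torch.cumsum(torch.tensor(query_lens, dtype=torch.int32), dim=0,
                 out=query_start_loc[1:])   # -> tensor([0, 3, 4, 8])

    # Zero-padded block tables, analogous to make_tensor_with_pad(..., pad=0).
    max_blocks = max(len(bt) for bt in block_tables)
    padded = torch.zeros(len(block_tables), max_blocks, dtype=torch.int32)
    for i, bt in enumerate(block_tables):
        padded[i, :len(bt)] = torch.tensor(bt, dtype=torch.int32)
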
class VACCAttentionBackendImpl(AttentionImpl[VACCAttentionMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"Torch SPDA does not support block-sparse attention.")
|
||||
if logits_soft_cap is not None:
|
||||
logger.warning_once("Torch SPDA does not support logits soft cap. "
|
||||
"Outputs may be slightly off.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
self.sliding_window = sliding_window
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.need_mask = (self.alibi_slopes is not None
|
||||
or self.sliding_window is not None)
|
||||
|
||||
supported_head_sizes = PagedAttention.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by PagedAttention. "
|
||||
f"Supported head sizes are: {supported_head_sizes}.")
|
||||
if kv_cache_dtype != "auto":
|
||||
raise NotImplementedError(
|
||||
"Torch SDPA backend does not support FP8 KV cache. "
|
||||
"Please use xFormers backend instead.")
|
||||
self.attn_type = attn_type
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: VACCAttentionMetadata, # type: ignore
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with torch SDPA and PagedAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
|
||||
NOTE: kv_cache will be an empty tensor with shape [0]
|
||||
for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
attn_type = self.attn_type
|
||||
if (attn_type == AttentionType.ENCODER
|
||||
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
||||
raise AttributeError("Encoder attention requires setting "
|
||||
"encoder metadata attributes.")
|
||||
elif (attn_type == AttentionType.ENCODER_DECODER
|
||||
and (not attn_metadata.is_all_cross_attn_metadata_set)):
|
||||
raise AttributeError("Encoder/decoder cross-attention "
|
||||
"requires setting cross-attention "
|
||||
"metadata attributes.")
|
||||
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
if key is not None:
|
||||
assert value is not None
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
else:
|
||||
assert value is None
|
||||
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
|
||||
# KV-cache during decoder-self- or
|
||||
# encoder-decoder-cross-attention, but not
|
||||
# during encoder attention.
|
||||
#
|
||||
# Even if there are no new key/value pairs to cache,
|
||||
# we still need to break out key_cache and value_cache
|
||||
# i.e. for later use by paged attention
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
|
||||
|
||||
if (key is not None) and (value is not None):
|
||||
if attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Update cross-attention KV cache (prefill-only)
|
||||
# During cross-attention decode, key & value will be None,
|
||||
# preventing this IF-statement branch from running
|
||||
updated_slot_mapping = attn_metadata.cross_slot_mapping
|
||||
else:
|
||||
# Update self-attention KV cache (prefill/decode)
|
||||
updated_slot_mapping = attn_metadata.slot_mapping
|
||||
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale, layer._v_scale)
|
||||
|
||||
if attn_type != AttentionType.ENCODER:
|
||||
# Decoder self-attention supports chunked prefill.
|
||||
# Encoder/decoder cross-attention requires no chunked
|
||||
# prefill (100% prefill or 100% decode tokens, no mix)
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
else:
|
||||
# Encoder attention - chunked prefill is not applicable;
|
||||
# derive token-count from query shape & and treat them
|
||||
# as 100% prefill tokens
|
||||
assert attn_metadata.num_encoder_tokens is not None
|
||||
num_prefill_tokens = attn_metadata.num_encoder_tokens
|
||||
num_decode_tokens = 0
|
||||
|
||||
if attn_type == AttentionType.DECODER:
|
||||
# Only enforce this shape-constraint for decoder
|
||||
# self-attention
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
output = torch.empty_like(query)
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
assert attn_metadata.seq_lens is not None
|
||||
if (kv_cache.numel() == 0
|
||||
or prefill_meta.block_tables.numel() == 0):
|
||||
self._run_vacc_forward(
|
||||
output,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
prefill_meta,
|
||||
attn_type=attn_type)
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
assert not self.need_mask
|
||||
import intel_extension_for_pytorch.llm.modules as ipex_modules
|
||||
output = torch.empty_like(query)
|
||||
ipex_modules.PagedAttention.flash_attn_varlen_func(
|
||||
output[:prefill_meta.num_prefill_tokens, :, :],
|
||||
query[:prefill_meta.num_prefill_tokens, :, :],
|
||||
key_cache,
|
||||
value_cache,
|
||||
prefill_meta.query_start_loc,
|
||||
prefill_meta.kv_start_loc,
|
||||
prefill_meta.max_query_len,
|
||||
prefill_meta.max_kv_len,
|
||||
self.scale,
|
||||
True,
|
||||
prefill_meta.prefill_block_tables,
|
||||
self.alibi_slopes,
|
||||
)
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
assert attn_type != AttentionType.ENCODER_ONLY, (
|
||||
"Encoder-only models should not have decode metadata.")
|
||||
# Decoding run.
|
||||
# (
|
||||
# seq_lens_arg,
|
||||
# max_seq_len_arg,
|
||||
# block_tables_arg,
|
||||
# ) = decode_meta.get_seq_len_block_table_args(attn_type)
|
||||
|
||||
# Note:
|
||||
# decode attention still use SDPA method
|
||||
# reshape k/v_cache to (num_block_grp, block_grp_size, head, hidden_size)
|
||||
k_cache = key_cache.view(-1, env_blk_grp_size, key_cache.shape[2], key_cache.shape[3])
|
||||
v_cache = value_cache.view(-1, env_blk_grp_size, value_cache.shape[2], value_cache.shape[3])
|
||||
block_per_group = env_blk_grp_size // 16
|
||||
# convert block_tables to 8K group index
|
||||
block_tables = (decode_meta.block_tables // block_per_group).to(torch.int32)
|
||||
attn_outs = []
|
||||
for i in range(decode_meta.seq_lens_tensor.shape[0]):
|
||||
seq_len = decode_meta.seq_lens_tensor[i]
|
||||
k_slices = k_cache[block_tables[i], ...]
|
||||
k = \
|
||||
torch.cat([k_slices[i, ...] for i in range(len(block_tables[i]))], dim=0)[:seq_len]
|
||||
v_slices = v_cache[block_tables[i], ...]
|
||||
v = \
|
||||
torch.cat([v_slices[i, ...] for i in range(len(block_tables[i]))], dim=0)[:seq_len]
|
||||
q = query[i : i + 1, ...]
|
||||
attn_out = torch.vacc.scaled_dot_product_attention(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=False,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=self.scale,
|
||||
)
|
||||
attn_outs.append(attn_out)
|
||||
output = torch.cat(attn_outs, dim=0)
|
||||
# '''
|
||||
|
||||
# PagedAttention.forward_decode(
|
||||
# output[attn_metadata.num_prefill_tokens:, :, :],
|
||||
# query[attn_metadata.num_prefill_tokens:, :, :],
|
||||
# key_cache,
|
||||
# value_cache,
|
||||
# block_tables_arg,
|
||||
# seq_lens_arg,
|
||||
# max_seq_len_arg,
|
||||
# self.kv_cache_dtype,
|
||||
# self.num_kv_heads,
|
||||
# self.scale,
|
||||
# self.alibi_slopes,
|
||||
# layer._k_scale,
|
||||
# layer._v_scale,
|
||||
# )
|
||||
# Reshape the output tensor.
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
|
||||
    def _run_vacc_forward(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: VACCAttentionMetadata,
        attn_type: AttentionType = AttentionType.DECODER,
    ):
        # if self.num_kv_heads != self.num_heads:
        #     key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
        #     value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
        attn_masks = attn_metadata.get_attn_bias(attn_type)
        if attn_masks is None:
            if self.alibi_slopes is not None:
                attn_masks = _make_alibi_bias(
                    self.alibi_slopes, query.dtype,
                    attn_metadata.seq_lens)  # type: ignore
            elif self.sliding_window is not None:
                assert attn_metadata.seq_lens is not None
                attn_masks = _make_sliding_window_bias(
                    attn_metadata.seq_lens, self.sliding_window,
                    query.dtype)  # type: ignore
            else:
                seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
                attn_masks = [None] * len(seq_lens)
            attn_metadata.set_attn_bias(attn_masks, attn_type)

        causal_attn = (attn_type == AttentionType.DECODER)

        seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
        start_q, start_kv = 0, 0
        for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
                                               attn_masks):
            end_q = start_q + seq_len_q
            end_kv = start_kv + seq_len_kv
            sub_out = torch.vacc.scaled_dot_product_attention(
                query[start_q:end_q, :, :],
                key[start_kv:end_kv, :, :],
                value[start_kv:end_kv, :, :].contiguous(),
                attn_mask=None,
                dropout_p=0.0,
                is_causal=True,  # causal_attn and not self.need_mask,
                is_train=False,
                recompute=False,
                flash_attention=False,
                sm_scale=self.scale)
            output[start_q:end_q, :, :] = sub_out
            start_q, start_kv = end_q, end_kv
        return output


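# A minimal, self-contained sketch of the per-sequence loop used by
# _run_vacc_forward above, written against the stock PyTorch SDPA so it can
# run on CPU. torch.vacc.scaled_dot_product_attention is assumed here to
# behave like torch.nn.functional.scaled_dot_product_attention; the helper
# name, the packed (num_tokens, num_heads, head_size) shapes, and the dummy
# sequence lengths are invented for the illustration.
def _packed_sdpa_reference(query, key, value, seq_lens_q, seq_lens_kv, scale):
    import torch
    import torch.nn.functional as F
    output = torch.empty_like(query)
    start_q, start_kv = 0, 0
    for seq_len_q, seq_len_kv in zip(seq_lens_q, seq_lens_kv):
        end_q, end_kv = start_q + seq_len_q, start_kv + seq_len_kv
        # SDPA expects (num_heads, seq_len, head_size), so move heads first.
        q = query[start_q:end_q].transpose(0, 1)
        k = key[start_kv:end_kv].transpose(0, 1)
        v = value[start_kv:end_kv].transpose(0, 1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True,
                                             scale=scale)
        output[start_q:end_q] = out.transpose(0, 1)
        start_q, start_kv = end_q, end_kv
    return output


# Example usage with two packed sequences of lengths 3 and 5:
#   q = k = v = torch.randn(8, 4, 64)
#   _packed_sdpa_reference(q, k, v, [3, 5], [3, 5], scale=64 ** -0.5)

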
def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[torch.Tensor]:
    attn_biases: List[torch.Tensor] = []
    for seq_len in seq_lens:
        bias = torch.arange(seq_len, dtype=dtype)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        bias = bias[None, :] - bias[:, None]

        num_heads = alibi_slopes.shape[0]
        bias = bias[None, :].repeat((num_heads, 1, 1))
        bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
        inf_mask = torch.empty(
            (1, seq_len, seq_len),
            dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
        attn_biases.append((bias + inf_mask).to(dtype))

    return attn_biases


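# Tiny, runnable illustration of the bias produced by _make_alibi_bias for a
# single sequence: for seq_len = 4 and two heads with made-up slopes
# [0.5, 0.25], entry [h, i, j] equals slope[h] * (j - i) on and below the
# diagonal and -inf above it, which is exactly the additive ALiBi mask. The
# helper name and all numbers are examples only.
def _alibi_bias_example():
    import torch
    seq_len, slopes = 4, torch.tensor([0.5, 0.25])
    pos = torch.arange(seq_len, dtype=torch.float32)
    bias = (pos[None, :] - pos[:, None])[None, :].repeat((2, 1, 1))
    bias = bias * slopes[:, None, None]
    inf_mask = torch.full((1, seq_len, seq_len), float("-inf")).triu_(1)
    return bias + inf_mask  # shape (2, 4, 4)

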
def _make_sliding_window_bias(
    seq_lens: List[int],
    window_size: Optional[int],
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    attn_biases: List[torch.Tensor] = []
    for seq_len in seq_lens:
        tensor = torch.full(
            (1, seq_len, seq_len),
            dtype=dtype,
            fill_value=1,
        )
        shift = 0
        mask = torch.tril(tensor, diagonal=shift).to(dtype)  # type: ignore
        if window_size is not None:
            mask = torch.triu(mask, diagonal=shift - window_size + 1)
        mask = torch.log(mask)
        attn_biases.append(mask.to(dtype))

    return attn_biases
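

# Runnable illustration of the log-of-mask trick used above: taking torch.log
# of a 0/1 tril/triu mask yields 0.0 inside the (causal, sliding) window and
# -inf everywhere else, so the result can be added directly to attention
# logits. The helper name and sizes below are arbitrary examples.
def _sliding_window_bias_example(seq_len: int = 6, window_size: int = 3):
    import torch
    ones = torch.full((1, seq_len, seq_len), 1.0)
    mask = torch.tril(ones)                             # causal part
    mask = torch.triu(mask, diagonal=-window_size + 1)  # keep last window_size tokens
    return torch.log(mask)                              # 0 inside window, -inf outside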
847
vllm_vacc/vllm/attention/backends/vacc_mla.py
Normal file
@@ -0,0 +1,847 @@
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
|
||||
try:
|
||||
from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
||||
except ImportError:
|
||||
BatchDecodeMlaWithPagedKVCacheWrapper = None
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionState, AttentionType)
|
||||
from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonMetadata
|
||||
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
|
||||
#from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm_vacc.vllm.attention.ops.vacc_paged_attn import VaccPagedAttention as PagedAttention
|
||||
# from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
# import time, os
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm_vacc.vllm.worker.vacc_model_runner import (ModelInputForVACCBuilder,
|
||||
ModelInputForVACCWithSamplingMetadata)
|
||||
|
||||
|
||||
class VACCMLABackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TORCH_VACC"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["VACCMLAImpl"]:
|
||||
return VACCMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||
return VACCMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["VACCMLAMetadataBuilder"]:
|
||||
return VACCMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["VACCMLAState"]:
|
||||
return VACCMLAState
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [576]
|
||||
|
||||
|
||||
class VACCMLAState(AttentionState):
|
||||
|
||||
def __init__(self, runner):
|
||||
self.runner = runner
|
||||
self._is_graph_capturing = False
|
||||
|
||||
@contextmanager
|
||||
def graph_capture(self, max_batch_size: int):
|
||||
self._is_graph_capturing = True
|
||||
|
||||
self._graph_slot_mapping = torch.full((max_batch_size, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
self._graph_seq_lens = torch.ones(max_batch_size,
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
self._graph_block_tables = torch.from_numpy(
|
||||
self.runner.graph_block_tables).to(device=self.runner.device)
|
||||
|
||||
self._positions = torch.zeros((max_batch_size, ),
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
|
||||
yield
|
||||
|
||||
self._is_graph_capturing = False
|
||||
del self._graph_slot_mapping
|
||||
del self._graph_seq_lens
|
||||
del self._graph_block_tables
|
||||
del self._positions
|
||||
|
||||
def graph_clone(self, batch_size: int):
|
||||
assert self._is_graph_capturing
|
||||
return self.__class__(self.runner)
|
||||
|
||||
def graph_capture_get_metadata_for_batch(
|
||||
self, batch_size: int, is_encoder_decoder_model: bool = False):
|
||||
assert self._is_graph_capturing
|
||||
|
||||
attn_metadata = self.runner.attn_backend.make_metadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=batch_size,
|
||||
slot_mapping=self._graph_slot_mapping[:batch_size],
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=self._graph_seq_lens[:batch_size],
|
||||
# max_query_len=1,
|
||||
# max_decode_query_len=1,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.runner.max_seq_len_to_capture,
|
||||
query_start_loc=None,
|
||||
seq_start_loc=None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=self._graph_block_tables[:batch_size],
|
||||
use_cuda_graph=True,
|
||||
input_positions=self._positions[:batch_size],
|
||||
head_dim=self.runner.model_config.get_head_size())
|
||||
|
||||
if is_encoder_decoder_model:
|
||||
raise NotImplementedError(
|
||||
"VACCMLAState does not support encoder/decoder yet")
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def get_graph_input_buffers(self,
|
||||
attn_metadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
input_buffers = {
|
||||
"slot_mapping": attn_metadata.slot_mapping,
|
||||
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
|
||||
"block_tables": attn_metadata.decode_metadata.block_tables,
|
||||
"input_positions": attn_metadata.decode_metadata.input_positions,
|
||||
}
|
||||
if is_encoder_decoder_model:
|
||||
raise NotImplementedError(
|
||||
"VACCMLAState does not support encoder/decoder yet")
|
||||
|
||||
return input_buffers
|
||||
|
||||
def prepare_graph_input_buffers(self,
|
||||
input_buffers,
|
||||
attn_metadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
input_positions = attn_metadata.input_positions
|
||||
num_positions = input_positions.shape[0]
|
||||
input_buffers["seq_lens_tensor"].copy_(
|
||||
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
|
||||
input_buffers["block_tables"].copy_(
|
||||
attn_metadata.decode_metadata.block_tables, non_blocking=True)
|
||||
# CUDA graph buffer is padded so only perform a partial copy based on
|
||||
# num_positions
|
||||
input_buffers["input_positions"][:num_positions].copy_(
|
||||
input_positions, non_blocking=True)
|
||||
if is_encoder_decoder_model:
|
||||
raise NotImplementedError(
|
||||
"VACCMLAState does not support encoder/decoder yet")
|
||||
|
||||
def begin_forward(self, model_input):
|
||||
return
|
||||
|
||||
|
||||
@dataclass
|
||||
class VACCMLAMetadata(MLACommonMetadata):
|
||||
"""Metadata for VACCMLAMetadata.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
cuda-graph replayed. If you have values that need to be changed
|
||||
dynamically, it should be stored in tensor. The tensor has to be
|
||||
updated from `CUDAGraphRunner.forward` API.
|
||||
"""
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens. None if it is a decoding.
|
||||
seq_lens: Optional[List[int]]
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
||||
# in the kv cache. Each block can contain up to block_size tokens.
|
||||
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
||||
# captured.
|
||||
block_tables: Optional[torch.Tensor]
|
||||
|
||||
# Whether or not cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
|
||||
use_cuda_graph: bool
|
||||
|
||||
# Maximum query length in the batch.
|
||||
max_query_len: Optional[int] = None
|
||||
input_positions: Optional[torch.Tensor] = None
|
||||
# Max number of query tokens among request in the batch.
|
||||
max_decode_query_len: Optional[int] = None
|
||||
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
_cached_prefill_metadata: Optional["VACCMLAMetadata"] = None
|
||||
_cached_decode_metadata: Optional["VACCMLAMetadata"] = None
|
||||
|
||||
num_prefill_tokens: int
|
||||
|
||||
num_kv_splits: int = 4 # TODO(lucas) add heuristic
|
||||
attn_logits: Optional[torch.Tensor] = None
|
||||
req_idx: Optional[torch.Tensor] = None
|
||||
|
||||
# The dimension of the attention heads
|
||||
head_dim: Optional[int] = None
|
||||
|
||||
def __post_init__(self):
|
||||
supported_head_sizes = VACCMLABackend.get_supported_head_sizes()
|
||||
if self.head_dim is not None and self.head_dim \
|
||||
not in supported_head_sizes:
|
||||
raise ValueError(
|
||||
f"Only {supported_head_sizes} are supported for head_dim,",
|
||||
f"received {self.head_dim}.")
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["VACCMLAMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
query_start_loc = (None if self.query_start_loc is None else
|
||||
self.query_start_loc[:self.num_prefills + 1])
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[:self.num_prefill_tokens])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[:self.num_prefills])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[:self.num_prefills])
|
||||
seq_start_loc = (None if self.seq_start_loc is None else
|
||||
self.seq_start_loc[:self.num_prefills + 1])
|
||||
context_lens_tensor = (None if self.context_lens_tensor is None else
|
||||
self.context_lens_tensor[:self.num_prefills])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
input_positions = (None if self.input_positions is None else
|
||||
self.input_positions[:self.num_prefill_tokens])
|
||||
|
||||
self._cached_prefill_metadata = VACCMLAMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
|
||||
input_positions=input_positions,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_prefill_seq_len=None,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_start_loc=seq_start_loc,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=False,
|
||||
head_dim=self.head_dim)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["VACCMLAMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
return self._cached_decode_metadata
|
||||
assert self.seq_lens_tensor is not None
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[self.num_prefill_tokens:])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
input_positions = (None if self.input_positions is None else
|
||||
self.input_positions[self.num_prefill_tokens:])
|
||||
|
||||
self._cached_decode_metadata = VACCMLAMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens=self.seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_decode_query_len=self.max_decode_query_len,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
# Batch may be composed of prefill|decodes, adjust query start
|
||||
# indices to refer to the start of decodes. E.g.
|
||||
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
|
||||
query_start_loc=(self.query_start_loc[self.num_prefills:] -
|
||||
self.query_start_loc[self.num_prefills])
|
||||
if self.query_start_loc is not None else None,
|
||||
seq_start_loc=self.seq_start_loc[self.num_prefills:]
|
||||
if self.seq_start_loc is not None else None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
input_positions=input_positions,
|
||||
head_dim=self.head_dim)
|
||||
return self._cached_decode_metadata
|
||||
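# Worked example of the prefill/decode split performed by the two cached
# properties above (numbers are illustrative): with num_prefills=2,
# num_prefill_tokens=7 and num_decode_tokens=3,
#   prefill_metadata slices seq_lens[:2], slot_mapping[:7], block_tables[:2];
#   decode_metadata slices seq_lens_tensor[2:], slot_mapping[7:],
#   block_tables[2:]; and query_start_loc=[0, 3, 7, 8, 9, 10] is rebased for
#   decodes to [7, 8, 9, 10] - 7 = [0, 1, 2, 3].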
|
||||
def advance_step(self,
|
||||
model_input: "ModelInputForVACCWithSamplingMetadata",
|
||||
sampled_token_ids: Optional[torch.Tensor],
|
||||
block_size: int,
|
||||
num_seqs: int,
|
||||
num_queries: int,
|
||||
turn_prefills_into_decodes: bool = False):
|
||||
"""
|
||||
Update metadata in-place to advance one decode step.
|
||||
"""
|
||||
# When using cudagraph, the num_seqs is padded to the next captured
|
||||
# batch size, but num_queries tracks the actual number of requests in
|
||||
# the batch. For --enforce-eager mode, num_seqs == num_queries
|
||||
if num_seqs != num_queries:
|
||||
assert num_seqs > num_queries
|
||||
assert self.use_cuda_graph
|
||||
|
||||
if turn_prefills_into_decodes:
|
||||
# When Multi-Step is enabled with Chunked-Prefill, prefills and
|
||||
# decodes are scheduled together. In the first step, all the
|
||||
# prefills turn into decodes. This update reflects that
|
||||
# conversion.
|
||||
assert self.num_decode_tokens + self.num_prefills == num_seqs
|
||||
self.num_decode_tokens += self.num_prefills
|
||||
self.num_prefills = 0
|
||||
# self.num_prefill_tokens = 0
|
||||
# self.max_prefill_seq_len = 0
|
||||
self.max_query_len = 1
|
||||
|
||||
self.slot_mapping = self.slot_mapping[:num_seqs]
|
||||
else:
|
||||
assert self.seq_lens is not None
|
||||
assert self.max_decode_seq_len == max(self.seq_lens)
|
||||
|
||||
assert self.num_prefills == 0
|
||||
assert self.num_prefill_tokens == 0
|
||||
assert self.num_decode_tokens == num_seqs
|
||||
assert self.slot_mapping.shape == (num_seqs, )
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert len(self.seq_lens) == num_seqs
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert self.seq_lens_tensor.shape == (num_seqs, )
|
||||
# assert self.max_query_len == 1
|
||||
# assert self.max_prefill_seq_len == 0
|
||||
|
||||
assert self.query_start_loc is not None
|
||||
assert self.query_start_loc.shape == (num_queries + 1, )
|
||||
assert self.seq_start_loc is not None
|
||||
assert self.seq_start_loc.shape == (num_seqs + 1, )
|
||||
|
||||
assert self.context_lens_tensor is not None
|
||||
assert self.context_lens_tensor.shape == (num_queries, )
|
||||
|
||||
assert self.block_tables is not None
|
||||
assert self.block_tables.shape[0] == num_seqs
|
||||
|
||||
# Update query lengths. Note that we update only queries and not seqs,
|
||||
# since tensors may be padded due to captured cuda graph batch size
|
||||
for i in range(num_queries):
|
||||
self.seq_lens[i] += 1
|
||||
# self.max_decode_seq_len = None
|
||||
|
||||
ops.advance_step_flashattn(num_seqs=num_seqs,
|
||||
num_queries=num_queries,
|
||||
block_size=block_size,
|
||||
input_tokens=model_input.input_tokens,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
input_positions=model_input.input_positions,
|
||||
seq_lens=self.seq_lens_tensor,
|
||||
slot_mapping=self.slot_mapping,
|
||||
block_tables=self.block_tables)
|
||||
|
||||
|
||||
class VACCMLAMetadataBuilder(AttentionMetadataBuilder[VACCMLAMetadata]):
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForVACCBuilder"):
|
||||
self.chunked_prefill = True
|
||||
if hasattr(input_builder, 'chunked_prefill'):
|
||||
self.chunked_prefill = input_builder.chunked_prefill
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
self.sliding_window = input_builder.sliding_window
|
||||
self.block_size = input_builder.block_size
|
||||
|
||||
def prepare(self):
|
||||
self.slot_mapping: List[int] = []
|
||||
self.prefill_seq_lens: List[int] = []
|
||||
self.context_lens: List[int] = []
|
||||
self.block_tables: List[List[int]] = []
|
||||
self.curr_seq_lens: List[int] = []
|
||||
self.input_positions: List[int] = []
|
||||
self.multimodal_placeholder_maps: Dict[
|
||||
str,
|
||||
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
|
||||
self.num_prefills = 0
|
||||
self.num_prefill_tokens = 0
|
||||
self.num_decode_tokens = 0
|
||||
self.has_prefix_cache_hit = False
|
||||
|
||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||
cuda_graph_pad_size: int, batch_size: int):
|
||||
"""Build attention metadata with on-device tensors.
|
||||
|
||||
Args:
|
||||
seq_lens: The maybe padded sequence lengths of the input sequences.
|
||||
query_lens: The query lengths of the input sequences.
|
||||
cuda_graph_pad_size: The padding size for cuda graph.
|
||||
-1 if cuda graph is not used.
|
||||
batch_size: The maybe padded batch size.
|
||||
"""
|
||||
|
||||
self.input_data = self.input_builder.input_data
|
||||
|
||||
self.slot_mapping=self.input_data.slot_mapping
|
||||
self.context_lens= self.input_data.context_lens
|
||||
if self.input_data.num_prefill_tokens !=0:
|
||||
|
||||
self.block_tables = self.input_data.prefill_block_tables
|
||||
else:
|
||||
self.block_tables= self.input_data.decode_block_tables
|
||||
self.input_positions= self.input_data.input_positions
|
||||
|
||||
self.prefill_seq_lens = seq_lens[0:self.input_data.num_prefills]
|
||||
|
||||
self.num_prefills = self.input_data.num_prefills
|
||||
self.num_prefill_tokens = self.input_data.num_prefill_tokens
|
||||
self.num_decode_tokens = self.input_data.num_decode_tokens
|
||||
|
||||
device = self.runner.device
|
||||
use_captured_graph = cuda_graph_pad_size != -1
|
||||
|
||||
# max_query_len = max(query_lens)
|
||||
# decode_query_lens = query_lens[self.num_prefills:]
|
||||
# if len(decode_query_lens) > 0:
|
||||
# max_decode_query_len = max(decode_query_lens)
|
||||
# else:
|
||||
# max_decode_query_len = 1
|
||||
# max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||
# max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
||||
num_decode_tokens = self.num_decode_tokens
|
||||
query_start_loc = list(accumulate(query_lens, initial=0))
|
||||
seq_start_loc = list(accumulate(seq_lens, initial=0))
|
||||
|
||||
num_seqs = len(seq_lens)
|
||||
if use_captured_graph:
|
||||
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
|
||||
self.block_tables.extend([] * cuda_graph_pad_size)
|
||||
num_decode_tokens = batch_size - self.num_prefill_tokens
|
||||
block_tables = self._get_graph_runner_block_tables(
|
||||
num_seqs, self.block_tables)
|
||||
else:
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
)
|
||||
# assert max_query_len > 0, ("query_lens: {}".format(query_lens))
|
||||
|
||||
assert device is not None
|
||||
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
|
||||
device, self.runner.pin_memory)
|
||||
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
|
||||
self.runner.pin_memory)
|
||||
input_positions = async_tensor_h2d(self.input_positions, torch.int,
|
||||
device, self.runner.pin_memory)
|
||||
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int,
|
||||
device, self.runner.pin_memory)
|
||||
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
|
||||
device,
|
||||
self.runner.pin_memory)
|
||||
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
|
||||
device, self.runner.pin_memory)
|
||||
placeholder_index_maps = {
|
||||
modality: placeholder_map.index_map()
|
||||
for modality, placeholder_map in
|
||||
self.multimodal_placeholder_maps.items()
|
||||
}
|
||||
return VACCMLAMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=True,
|
||||
input_positions=input_positions,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
# max_query_len=max_query_len,
|
||||
# max_decode_query_len=None,
|
||||
max_prefill_seq_len=None,
|
||||
max_decode_seq_len=None,
|
||||
query_start_loc=query_start_loc_tensor,
|
||||
seq_start_loc=seq_start_loc_tensor,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=use_captured_graph,
|
||||
num_kv_splits=4, # TODO(lucas) add heuristic
|
||||
head_dim=self.runner.model_config.get_head_size(),
|
||||
)
|
||||
|
||||
|
||||
class VACCMLAImpl(MLACommonImpl[VACCMLAMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
**kwargs) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **kwargs)
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"VACCMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"VACCMLAImpl")
|
||||
|
||||
def extract_weights(self):
|
||||
weights = {}
|
||||
if hasattr(self, 'W_Q'):
|
||||
weights["W_Q"] = self.W_Q
|
||||
if hasattr(self, 'W_Q_scales'):
|
||||
weights["W_Q_scales"] = self.W_Q_scales
|
||||
if hasattr(self, 'W_QR'):
|
||||
weights['W_QR'] = self.W_QR
|
||||
if hasattr(self, 'W_QR_scales'):
|
||||
weights["W_QR_scales"] = self.W_QR_scales
|
||||
if hasattr(self, 'W_Q_QR'):
|
||||
weights["W_Q_QR"] = self.W_Q_QR
|
||||
if hasattr(self, 'W_Q_QR_scales'):
|
||||
weights["W_Q_QR_scales"] = self.W_Q_QR_scales
|
||||
if hasattr(self, 'W_UK'):
|
||||
weights['W_UK'] = self.W_UK
|
||||
if hasattr(self, 'W_UK_scales'):
|
||||
weights['W_UK_scales'] = self.W_UK_scales
|
||||
if hasattr(self, 'W_Q_UK_scales'):
|
||||
weights['W_Q_UK_scales'] = self.W_Q_UK_scales
|
||||
if hasattr(self, 'W_UV'):
|
||||
weights['W_UV'] = self.W_UV
|
||||
if hasattr(self, 'W_UV_scales'):
|
||||
weights['W_UV_scales'] = self.W_UV_scales
|
||||
if hasattr(self, 'W_UV_O'):
|
||||
weights['W_UV_O'] = self.W_UV_O
|
||||
if hasattr(self, 'W_UV_O_scales'):
|
||||
weights['W_UV_O_scales'] = self.W_UV_O_scales
|
||||
return weights
|
||||
|
||||
def _forward_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: VACCMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert isinstance(attn_metadata, VACCMLAMetadata)
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0]\
|
||||
.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope\
|
||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
|
||||
v = v.contiguous()
|
||||
|
||||
# For MLA the v head dim is smaller than qk head dim so we pad out
|
||||
# v with 0s to match the qk head dim
|
||||
# v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
|
||||
# value=0)
|
||||
# attn_output = torch.vacc.scaled_dot_product_attention(
|
||||
# query=q,
|
||||
# key=k,
|
||||
# value=v_padded,
|
||||
# attn_mask=None,
|
||||
# dropout_p=0,
|
||||
# is_causal=True,
|
||||
# is_train=False,
|
||||
# recompute=False,
|
||||
# flash_attention=True,
|
||||
# sm_scale=self.scale
|
||||
# )
|
||||
|
||||
# attn_output = attn_output\
|
||||
# .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
|
||||
# .reshape(-1, self.num_heads * v.shape[-1])
|
||||
seq_lens = attn_metadata.prefill_metadata.seq_lens
|
||||
if len(seq_lens) == 1:
|
||||
# Vacc supports different head dims for v and qk.
|
||||
attn_output = torch.vacc.scaled_dot_product_attention(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=True,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=self.scale
|
||||
)
|
||||
attn_out = attn_output.view(-1, self.num_heads * v.shape[-1])
|
||||
else:
|
||||
attn_outs = []
|
||||
start = 0
|
||||
for seq in seq_lens:
|
||||
end = start + seq
|
||||
attn_out = torch.vacc.scaled_dot_product_attention(
|
||||
query=q[start:end, :],
|
||||
key=k[start:end, :],
|
||||
value=v[start:end, :],
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=True,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=self.scale
|
||||
)
|
||||
start = end
|
||||
attn_outs.append(attn_out)
|
||||
attn_out = torch.cat(attn_outs, dim=0).view(-1, self.num_heads * v.shape[-1])
|
||||
|
||||
return self.o_proj(attn_out)[0]
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: VACCMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Triton MLA not yet supported")
|
||||
|
||||
decode_meta = attn_metadata.decode_metadata
|
||||
assert decode_meta is not None
|
||||
B = q_nope.shape[0]
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
o = torch.zeros(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
|
||||
# Add a head dim of 1
|
||||
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||
# print(f"kv_c_and_k_pe_cache: {kv_c_and_k_pe_cache.shape} ")
|
||||
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
|
||||
|
||||
# Run MQA using paged_attention
|
||||
# o = torch.vacc.paged_attention(
|
||||
# query=q,
|
||||
# key_cache=kv_c_and_k_pe_cache,
|
||||
# value_cache=kv_c_cache,
|
||||
# block_table=decode_meta.block_tables,
|
||||
# seq_len=decode_meta.seq_lens_tensor,
|
||||
# out=o,
|
||||
# sm_scale=self.scale
|
||||
# )
|
||||
|
||||
# Run MQA using SDPA
|
||||
# t0 = time.time()
|
||||
o = vacc_paged_attention_naive(
|
||||
q,
|
||||
kv_c_and_k_pe_cache,
|
||||
kv_c_cache,
|
||||
block_table = decode_meta.block_tables,
|
||||
# seq_lens = decode_meta.seq_lens_tensor,
|
||||
seq_lens=decode_meta.seq_lens,
|
||||
out = o,
|
||||
sm_scale=self.scale)
|
||||
# print(f'{os.getpid()} paged_atten(seq: {decode_meta.seq_lens}) time: {time.time() - t0}')
|
||||
|
||||
return self._v_up_proj_and_o_proj(o)
|
||||
|
||||
def vacc_paged_attention_naive(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
# seq_lens: torch.Tensor,
|
||||
seq_lens: int,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
sm_scale = -1
|
||||
) -> torch.Tensor:
|
||||
|
||||
# guarantee batch=1 perf
|
||||
if len(seq_lens) == 1:
|
||||
k = key_cache.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[0]]
|
||||
v = value_cache.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[0]]
|
||||
attn_out = torch.vacc.scaled_dot_product_attention(
|
||||
query=query,
|
||||
key=k,
|
||||
value=v,
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=False,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=sm_scale
|
||||
)
|
||||
else:
|
||||
# t0 = time.time()
|
||||
attn_outs = []
|
||||
for i in range(len(seq_lens)):
|
||||
k_slices = key_cache[block_table[i], :, :, :]
|
||||
k = torch.cat([k_slices[i, :, :, :].unsqueeze(1) for i in range(len(block_table[i]))], dim=0)
|
||||
k = k.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[i]]
|
||||
v_slices = value_cache[block_table[i], :, :, :]
|
||||
v = torch.cat([v_slices[i, :, :, :].unsqueeze(1) for i in range(len(block_table[i]))], dim=0)
|
||||
v = v.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[i]]
|
||||
|
||||
attn_out = torch.vacc.scaled_dot_product_attention(
|
||||
query=query[i:i+1,:,:],
|
||||
key=k,
|
||||
value=v,
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=False,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=sm_scale
|
||||
)
|
||||
attn_outs.append(attn_out)
|
||||
|
||||
attn_out = torch.cat(attn_outs, dim=0)
|
||||
# print(f'{os.getpid()} call spda(seq: {seq_lens}) time: {time.time() - t0}')
|
||||
return attn_out
|
||||
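# CPU-only reference for the gather performed by vacc_paged_attention_naive:
# the blocks listed in a per-sequence block table are concatenated, truncated
# to the true sequence length, and attended with a single query token.
# Standard torch SDPA stands in for torch.vacc here, and the helper name,
# shapes, and block ids are invented for the example.
def _paged_gather_reference():
    import torch
    import torch.nn.functional as F
    num_blocks, block_size, num_heads, head_size = 8, 4, 2, 16
    key_cache = torch.randn(num_blocks, block_size, num_heads, head_size)
    value_cache = torch.randn_like(key_cache)
    block_table = torch.tensor([5, 2, 7])     # blocks owned by one sequence
    seq_len = 10                              # only 10 of the 12 slots are valid
    k = key_cache[block_table].reshape(-1, num_heads, head_size)[:seq_len]
    v = value_cache[block_table].reshape(-1, num_heads, head_size)[:seq_len]
    q = torch.randn(1, num_heads, head_size)  # one decode token
    out = F.scaled_dot_product_attention(q.transpose(0, 1),
                                         k.transpose(0, 1),
                                         v.transpose(0, 1),
                                         scale=head_size ** -0.5)
    return out.transpose(0, 1)                # (1, num_heads, head_size)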
|
||||
# MLA single op impl
|
||||
def vacc_paged_attention_naive_singleop(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
seq_lens,
|
||||
block_table = None,
|
||||
out: torch.Tensor = None,
|
||||
sm_scale = -1
|
||||
) -> torch.Tensor:
|
||||
k = key_cache.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens]
|
||||
v = value_cache.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens].squeeze(1)
|
||||
pe_cache = k[..., 512:].squeeze(1)
|
||||
print(f'q:{query[..., :512].shape} v:{v.shape} pe_cache:{pe_cache.shape}')
|
||||
q_nope_kv_c = torch.einsum("shc,tc->sht", query[..., :512], v)
|
||||
q_pe_k_pe = torch.einsum("shr,tr->sht", query[..., 512:], pe_cache)
|
||||
scores = (q_nope_kv_c + q_pe_k_pe) * sm_scale
|
||||
scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(query)
|
||||
o = torch.einsum("sht,tc->shc", scores, v)
|
||||
return o
|
||||
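# Self-contained numeric sketch of the MLA decode math used in
# vacc_paged_attention_naive_singleop above: the query is split into a "nope"
# part scored against the compressed KV latent (kv_lora_rank wide) and a
# rotary part scored against the cached k_pe; the two score terms are summed,
# softmaxed, and used to weight the latent cache. The 512/64 widths mirror
# the slicing above, but the helper name and all sizes are example values.
def _mla_decode_reference(num_tokens: int = 3, num_heads: int = 2,
                          ctx_len: int = 5, kv_lora_rank: int = 512,
                          rope_dim: int = 64, sm_scale: float = 0.1):
    import torch
    q = torch.randn(num_tokens, num_heads, kv_lora_rank + rope_dim)
    kv_c = torch.randn(ctx_len, kv_lora_rank)  # compressed KV latent cache
    k_pe = torch.randn(ctx_len, rope_dim)      # rotary key cache
    scores = (torch.einsum("shc,tc->sht", q[..., :kv_lora_rank], kv_c) +
              torch.einsum("shr,tr->sht", q[..., kv_lora_rank:], k_pe))
    probs = (scores * sm_scale).softmax(dim=-1)
    return torch.einsum("sht,tc->shc", probs, kv_c)  # (num_tokens, num_heads, 512)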
0
vllm_vacc/vllm/attention/ops/__init__.py
Normal file
Binary file not shown.
160
vllm_vacc/vllm/attention/ops/vacc_paged_attn.py
Normal file
@@ -0,0 +1,160 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||
|
||||
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
|
||||
_PARTITION_SIZE = 512
|
||||
|
||||
@dataclass
|
||||
class PagedAttentionMetadata:
|
||||
"""Metadata for PagedAttention."""
|
||||
# (batch_size,). The length of sequences (entire tokens seen so far) per
|
||||
# sequence.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
||||
# in the kv cache. Each block can contain up to block_size tokens.
|
||||
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
||||
# captured.
|
||||
block_tables: Optional[torch.Tensor]
|
||||
|
||||
class VaccPagedAttention:
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [32, 64, 80, 96, 112, 120, 128, 192, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size * num_kv_heads * head_size)
|
||||
|
||||
@staticmethod
|
||||
def split_kv_cache(
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# x = 16 // kv_cache.element_size()
|
||||
num_blocks = kv_cache.shape[1]
|
||||
|
||||
key_cache = kv_cache[0]
|
||||
key_cache = key_cache.view(num_blocks, -1,num_kv_heads, head_size)
|
||||
value_cache = kv_cache[1]
|
||||
value_cache = value_cache.view(num_blocks, -1, num_kv_heads, head_size)
|
||||
return key_cache, value_cache
|
||||
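# Shape walk-through for split_kv_cache (illustrative numbers): with
# num_kv_heads=8, head_size=128 and block_size=16, get_kv_cache_shape gives a
# kv_cache of shape (2, num_blocks, 16 * 8 * 128); split_kv_cache then views
# kv_cache[0] and kv_cache[1] as (num_blocks, 16, 8, 128) key and value
# caches respectively.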
|
||||
@staticmethod
|
||||
def write_to_paged_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
) -> None:
|
||||
# list_from_tensor = slot_mapping.tolist()
|
||||
torch.vacc.reshape_and_cache_attention(key,key_cache,slot_mapping)
|
||||
torch.vacc.reshape_and_cache_attention(value,value_cache,slot_mapping)
|
||||
|
||||
@staticmethod
|
||||
def forward_decode(
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
max_seq_len: int,
|
||||
kv_cache_dtype: str,
|
||||
num_kv_heads: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
tp_rank: int = 0,
|
||||
blocksparse_local_blocks: int = 0,
|
||||
blocksparse_vert_stride: int = 0,
|
||||
blocksparse_block_size: int = 64,
|
||||
blocksparse_head_sliding_step: int = 0,
|
||||
) -> torch.Tensor:
|
||||
torch.vacc.paged_attention(query,key_cache,value_cache,block_tables,seq_lens,-1,output)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def forward_prefix(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
seq_lens_tensor: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
max_query_len: int,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
sliding_window: Optional[int],
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
) -> torch.Tensor:
|
||||
output = torch.empty_like(query)
|
||||
context_attention_fwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables,
|
||||
# query_start_loc is (batch_size + 1,)
|
||||
query_start_loc[:-1],
|
||||
seq_lens_tensor,
|
||||
context_lens,
|
||||
max_query_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
src_key_cache = src_kv_cache[0]
|
||||
dst_key_cache = dst_kv_cache[0]
|
||||
torch.vacc.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
|
||||
|
||||
src_value_cache = src_kv_cache[1]
|
||||
dst_value_cache = dst_kv_cache[1]
|
||||
torch.vacc.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
key_caches = [kv_cache[0] for kv_cache in kv_caches]
|
||||
value_caches = [kv_cache[1] for kv_cache in kv_caches]
|
||||
torch.vacc.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
154
vllm_vacc/vllm/config.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import ast
|
||||
import copy
|
||||
import enum
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import textwrap
|
||||
import warnings
|
||||
from collections import Counter
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
|
||||
replace)
|
||||
from importlib.util import find_spec
|
||||
from pathlib import Path
|
||||
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
|
||||
Optional, Protocol, TypeVar, Union, get_args)
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
|
||||
QuantizationMethods,
|
||||
get_quantization_config)
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.config.model import _STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def ModelConfig___verify_quantization(self) -> None:
|
||||
supported_quantization = QUANTIZATION_METHODS
|
||||
optimized_quantization_methods = [
|
||||
"fp8", "modelopt", "gptq_marlin_24", "gptq_marlin",
|
||||
"awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
|
||||
"quark", "modelopt_fp4", "bitblas"#, "gptq_bitblas"
|
||||
]
|
||||
if self.quantization is not None:
|
||||
self.quantization = self.quantization.lower()
|
||||
|
||||
# Parse quantization method from the HF model config, if available.
|
||||
quant_cfg = self._parse_quant_hf_config(self.hf_config)
|
||||
if quant_cfg is None and (text_config := getattr(
|
||||
self.hf_config, "text_config", None)):
|
||||
# Check the text config as well for multi-modal models.
|
||||
quant_cfg = self._parse_quant_hf_config(text_config)
|
||||
|
||||
if quant_cfg is not None:
|
||||
quant_method = quant_cfg.get("quant_method", "").lower()
|
||||
quant_method = quant_method.replace("compressed_tensors",
|
||||
"compressed-tensors")
|
||||
quant_cfg["quant_method"] = quant_method
|
||||
|
||||
# Quantization methods which are overrides (i.e. they have a
|
||||
# `override_quantization_method` method) must be checked in order
|
||||
# of preference (this is particularly important for GPTQ).
|
||||
overrides = [
|
||||
# "marlin",
|
||||
"bitblas",
|
||||
"gptq_marlin_24",
|
||||
"gptq_marlin",
|
||||
# "gptq_bitblas",
|
||||
"awq_marlin",
|
||||
"ipex",
|
||||
"moe_wna16",
|
||||
"modelopt",
|
||||
"modelopt_fp4",
|
||||
"petit_nvfp4",
|
||||
]
|
||||
quantization_methods = [
|
||||
q for q in supported_quantization if q not in overrides
|
||||
]
|
||||
# Any custom overrides will be in quantization_methods so we place
|
||||
# them at the start of the list so custom overrides have preference
|
||||
# over the built in ones.
|
||||
quantization_methods = quantization_methods + overrides
|
||||
|
||||
# Detect which checkpoint it is
|
||||
for name in quantization_methods:
|
||||
method = get_quantization_config(name)
|
||||
quantization_override = method.override_quantization_method(
|
||||
quant_cfg, self.quantization)
|
||||
if quantization_override is not None:
|
||||
# Raise error if the override is not custom (custom would
|
||||
# be in QUANTIZATION_METHODS but not QuantizationMethods)
|
||||
# and hasn't been added to the overrides list.
|
||||
if (name in get_args(QuantizationMethods)
|
||||
and name not in overrides):
|
||||
raise ValueError(
|
||||
f"Quantization method {name} is an override but "
|
||||
"is has not been added to the `overrides` list "
|
||||
"above. This is necessary to ensure that the "
|
||||
"overrides are checked in order of preference.")
|
||||
quant_method = quantization_override
|
||||
self.quantization = quantization_override
|
||||
break
|
||||
|
||||
# Verify quantization configurations.
|
||||
if self.quantization is None:
|
||||
self.quantization = quant_method
|
||||
elif self.quantization != quant_method:
|
||||
raise ValueError(
|
||||
"Quantization method specified in the model config "
|
||||
f"({quant_method}) does not match the quantization "
|
||||
f"method specified in the `quantization` argument "
|
||||
f"({self.quantization}).")
|
||||
|
||||
if self.quantization is not None:
|
||||
if self.quantization not in supported_quantization:
|
||||
raise ValueError(
|
||||
f"Unknown quantization method: {self.quantization}. Must "
|
||||
f"be one of {supported_quantization}.")
|
||||
from vllm.platforms import current_platform
|
||||
current_platform.verify_quantization(self.quantization)
|
||||
if self.quantization not in optimized_quantization_methods:
|
||||
logger.warning(
|
||||
"%s quantization is not fully "
|
||||
"optimized yet. The speed can be slower than "
|
||||
"non-quantized models.", self.quantization)
|
||||
|
||||
|
||||
def _get_head_dtype(config: PretrainedConfig, dtype: torch.dtype,
|
||||
runner_type: str) -> torch.dtype:
|
||||
head_dtype: Optional[Union[str,
|
||||
torch.dtype]] = getattr(config, "head_dtype",
|
||||
None)
|
||||
|
||||
if head_dtype == "model":
|
||||
return dtype
|
||||
elif isinstance(head_dtype, str):
|
||||
head_dtype = head_dtype.lower()
|
||||
if head_dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
|
||||
raise ValueError(f"Unknown dtype: {head_dtype!r}")
|
||||
return _STR_DTYPE_TO_TORCH_DTYPE[head_dtype]
|
||||
elif isinstance(head_dtype, torch.dtype):
|
||||
return head_dtype
|
||||
elif head_dtype is None:
|
||||
if torch.float32 not in current_platform.supported_dtypes:
|
||||
return dtype
|
||||
if runner_type == "pooling":
|
||||
return torch.float16
|
||||
return dtype
|
||||
else:
|
||||
raise ValueError(f"Unknown dtype: {head_dtype}")
|
||||
52
vllm_vacc/vllm/config_manager.py
Normal file
@@ -0,0 +1,52 @@
|
||||
|
||||
#####################################################
|
||||
## 1. used for the Memory-Recycler
|
||||
## .model_infos
|
||||
## ['deepseek_mtp',]
|
||||
##
|
||||
## 2. waiting...
|
||||
######################################################
|
||||
|
||||
import os
|
||||
class ConfigManager():
|
||||
def __init__(self):
|
||||
self._config_name = ".model_infos"
|
||||
|
||||
def update_model_infos(self, model_infos : str):
|
||||
from pathlib import Path
|
||||
workspace_path = Path.cwd()
|
||||
|
||||
bootinfo_config = f'{workspace_path}/{self._config_name}'
|
||||
try:
|
||||
with open(bootinfo_config, 'w') as w:
|
||||
w.write(model_infos)
|
||||
except Exception as e:
|
||||
print("[WARN] write model_infos fail, caused by ", e)
|
||||
raise False
|
||||
|
||||
def get_model_infos(self):
|
||||
from pathlib import Path
|
||||
workspace_path = Path.cwd()
|
||||
|
||||
bootinfo_config = f'{workspace_path}/{self._config_name}'
|
||||
bootinfo_inited = os.path.exists(bootinfo_config)
|
||||
|
||||
runner_model_infos = "default"
|
||||
if bootinfo_inited:
|
||||
try:
|
||||
with open(bootinfo_config) as w:
|
||||
runner_model_infos = w.readline()
|
||||
except Exception as e:
|
||||
print("[WARN] model_infos load fail ", e)
|
||||
|
||||
return runner_model_infos
|
||||
|
||||
config_manager = None
|
||||
|
||||
def vllm_vacc_config_manager():
|
||||
global config_manager
|
||||
|
||||
if config_manager is None:
|
||||
config_manager = ConfigManager()
|
||||
return config_manager
|
||||
|
||||
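# Example of how the singleton above is meant to be used; it writes and then
# re-reads the ".model_infos" marker file in the current working directory.
# (The model name below is just an example value.)
if __name__ == "__main__":
    cm = vllm_vacc_config_manager()
    cm.update_model_infos("deepseek_mtp")
    print(cm.get_model_infos())  # -> "deepseek_mtp"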
0
vllm_vacc/vllm/core/__init__.py
Normal file
BIN
vllm_vacc/vllm/core/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/core/__pycache__/block_manager.cpython-312.pyc
Normal file
Binary file not shown.
0
vllm_vacc/vllm/core/block/__init__.py
Normal file
BIN
vllm_vacc/vllm/core/block/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/core/block/__pycache__/common.cpython-312.pyc
Normal file
Binary file not shown.
407
vllm_vacc/vllm/core/block/block_table.py
Normal file
@@ -0,0 +1,407 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import math
|
||||
from typing import List, Optional
|
||||
|
||||
from vllm.core.block.common import BlockList
|
||||
from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator
|
||||
from vllm.utils import Device, cdiv, chunk_list
|
||||
|
||||
|
||||
class BlockTable:
|
||||
"""A class to manage blocks for a specific sequence.
|
||||
|
||||
The BlockTable maps a sequence of tokens to a list of blocks, where each
|
||||
block represents a contiguous memory allocation for a portion of the
|
||||
sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is
|
||||
responsible for allocating and freeing memory for the blocks.
|
||||
|
||||
Args:
|
||||
block_size (int): The maximum number of tokens that can be stored in a
|
||||
single block.
|
||||
block_allocator (DeviceAwareBlockAllocator): The block allocator used to
|
||||
manage memory for the blocks.
|
||||
_blocks (Optional[List[Block]], optional): An optional list of existing
|
||||
blocks to initialize the BlockTable with. If not provided, an empty
|
||||
BlockTable is created.
|
||||
max_block_sliding_window (Optional[int], optional): The number of
|
||||
blocks to keep around for each sequence. If None, all blocks
|
||||
are kept (eg., when sliding window is not used).
|
||||
It should at least fit the sliding window size of the model.
|
||||
|
||||
Attributes:
|
||||
_block_size (int): The maximum number of tokens that can be stored in a
|
||||
single block.
|
||||
_allocator (DeviceAwareBlockAllocator): The block allocator used to
|
||||
manage memory for the blocks.
|
||||
_blocks (Optional[List[Block]]): The list of blocks managed by this
|
||||
BlockTable.
|
||||
_num_full_slots (int): The number of tokens currently stored in the
|
||||
blocks.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_size: int,
|
||||
block_allocator: DeviceAwareBlockAllocator,
|
||||
_blocks: Optional[List[Block]] = None,
|
||||
max_block_sliding_window: Optional[int] = None,
|
||||
):
|
||||
self._block_size = block_size
|
||||
self._allocator = block_allocator
|
||||
if _blocks is None:
|
||||
_blocks = []
|
||||
self._blocks: BlockList = BlockList(_blocks)
|
||||
|
||||
self._max_block_sliding_window = max_block_sliding_window
|
||||
self._num_full_slots = self._get_num_token_ids()
|
||||
|
||||
@staticmethod
|
||||
def get_num_required_blocks(token_ids: List[int],
|
||||
block_size: int,
|
||||
num_lookahead_slots: int = 0) -> int:
|
||||
"""Calculates the minimum number of blocks required to store a given
|
||||
sequence of token IDs along with any look-ahead slots that may be
|
||||
required (like in multi-step + chunked-prefill).
|
||||
|
||||
This assumes worst-case scenario, where every block requires a new
|
||||
allocation (e.g. ignoring prefix caching).
|
||||
|
||||
Args:
|
||||
token_ids (List[int]): The sequence of token IDs to be stored.
|
||||
block_size (int): The maximum number of tokens that can be stored in
|
||||
a single block.
|
||||
num_lookahead_slots (int): look-ahead slots that the sequence may
|
||||
require.
|
||||
|
||||
Returns:
|
||||
int: The minimum number of blocks required to store the given
|
||||
sequence of token IDs along with any required look-ahead slots.
|
||||
"""
|
||||
return cdiv(len(token_ids) + num_lookahead_slots, block_size)
|
||||
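# Worked example for get_num_required_blocks (illustrative numbers): with
# 10 token ids, block_size=4 and num_lookahead_slots=3, the worst case needs
# cdiv(10 + 3, 4) = 4 blocks.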
|
||||
def allocate(self,
|
||||
token_ids: List[int],
|
||||
device: Device = Device.GPU,
|
||||
extra_hash: Optional[int] = None,
|
||||
seq_id: Optional[int] = None) -> None:
|
||||
"""Allocates memory blocks for storing the given sequence of token IDs.
|
||||
|
||||
This method allocates the required number of blocks to store the given
|
||||
sequence of token IDs.
|
||||
|
||||
Args:
|
||||
token_ids (List[int]): The sequence of token IDs to be stored.
|
||||
device (Device, optional): The device on which the blocks should be
|
||||
allocated. Defaults to Device.GPU.
|
||||
extra_hash (Optional[int]): The hash value of additional
|
||||
factors, such as adapters, that influence the block hash
|
||||
in the prefix caching block.
|
||||
"""
|
||||
assert not self._is_allocated
|
||||
assert token_ids
|
||||
blocks = self._allocate_blocks_for_token_ids(prev_block=None,
|
||||
token_ids=token_ids,
|
||||
device=device,
|
||||
extra_hash=extra_hash,
|
||||
seq_id=seq_id)
|
||||
self.update(blocks)
|
||||
self._num_full_slots = len(token_ids)
|
||||
|
||||
def update(self, blocks: List[Block]) -> None:
|
||||
"""Resets the table to the newly provided blocks
|
||||
(with their corresponding block ids)
|
||||
"""
|
||||
self._blocks.update(blocks)
|
||||
|
||||
def append_token_ids(self,
|
||||
token_ids: List[int],
|
||||
num_lookahead_slots: int = 0,
|
||||
num_computed_slots: Optional[int] = None,
|
||||
extra_hash: Optional[int] = None,
|
||||
seq_id: Optional[int] = None) -> None:
|
||||
"""Appends a sequence of token IDs to the existing blocks in the
|
||||
BlockTable.
|
||||
|
||||
This method appends the given sequence of token IDs to the existing
|
||||
blocks in the BlockTable. If there is not enough space in the existing
|
||||
blocks, new blocks are allocated using the `ensure_num_empty_slots`
|
||||
method to accommodate the additional tokens.
|
||||
|
||||
The token IDs are divided into chunks of size `block_size` (except for
|
||||
the first chunk, which may be smaller), and each chunk is appended to a
|
||||
separate block.
|
||||
|
||||
Args:
|
||||
token_ids (List[int]): The sequence of token IDs to be appended.
|
||||
num_computed_slots (Optional[int]): The number of KV cache slots
|
||||
that are already filled (computed).
|
||||
When sliding window is enabled, this is used to compute how many
|
||||
blocks to drop at the front of the sequence.
|
||||
Without sliding window, None can be passed.
|
||||
Without chunked prefill, it should be the same as
|
||||
_num_full_slots.
|
||||
extra_hash (Optional[int]): The hash value of additional
|
||||
factors such as adapters that influence the block, apart
|
||||
from the token_ids.
|
||||
"""
|
||||
assert self._is_allocated, "no blocks have been allocated"
|
||||
assert len(self._blocks) > 0
|
||||
|
||||
# Drop blocks that are no longer needed due to sliding window
|
||||
if self._max_block_sliding_window is not None:
|
||||
null_block = self._allocator.allocate_or_get_null_block()
|
||||
assert num_computed_slots is not None
|
||||
end_block_idx = (num_computed_slots //
|
||||
self._block_size) - self._max_block_sliding_window
|
||||
for idx in range(0, end_block_idx):
|
||||
b = self._blocks[idx]
|
||||
if b is not null_block:
|
||||
self._allocator.free(b)
|
||||
self._blocks[idx] = null_block
|
||||
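# Worked example of the sliding-window drop above (illustrative numbers):
# with block_size=16, max_block_sliding_window=4 and num_computed_slots=100,
# end_block_idx = 100 // 16 - 4 = 2, so blocks 0 and 1 are freed and replaced
# by the shared null block before the new tokens are appended.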
|
||||
# Ensure there are enough empty slots for the new tokens plus
|
||||
# lookahead slots
|
||||
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
|
||||
num_lookahead_slots,
|
||||
extra_hash=extra_hash,
|
||||
seq_id=seq_id)
|
||||
|
||||
# Update the blocks with the new tokens
|
||||
first_block_idx = self._num_full_slots // self._block_size
|
||||
token_blocks = self._chunk_token_blocks_for_append(token_ids)
|
||||
|
||||
for i, token_block in enumerate(token_blocks):
|
||||
self._blocks.append_token_ids(first_block_idx + i, token_block, seq_id=seq_id)
|
||||
|
||||
self._num_full_slots += len(token_ids)
|
||||
|
||||
def ensure_num_empty_slots(self,
|
||||
num_empty_slots: int,
|
||||
extra_hash: Optional[int] = None,
|
||||
seq_id: Optional[int] = None) -> None:
|
||||
"""Ensures that the BlockTable has at least the specified number of
|
||||
empty slots available.
|
||||
|
||||
This method checks if the BlockTable has enough empty slots (i.e.,
|
||||
available space) to accommodate the requested number of tokens. If not,
|
||||
it allocates additional blocks on the GPU to ensure that the required
|
||||
number of empty slots is available.
|
||||
|
||||
Args:
|
||||
num_empty_slots (int): The minimum number of empty slots required.
|
||||
extra_hash (Optional[int]): The hash value of additional
|
||||
factors such as adapters that influence the block, apart
|
||||
from the token_ids.
|
||||
"""
|
||||
# Currently the block table only supports
|
||||
# appending tokens to GPU blocks.
|
||||
device = Device.GPU
|
||||
assert self._is_allocated
|
||||
|
||||
if self._num_empty_slots >= num_empty_slots:
|
||||
return
|
||||
|
||||
slots_to_allocate = num_empty_slots - self._num_empty_slots
|
||||
blocks_to_allocate = cdiv(slots_to_allocate, self._block_size)
|
||||
|
||||
for _ in range(blocks_to_allocate):
|
||||
assert len(self._blocks) > 0
|
||||
self._blocks.append(
|
||||
self._allocator.allocate_mutable_block(
|
||||
prev_block=self._blocks[-1],
|
||||
device=device,
|
||||
extra_hash=extra_hash,
|
||||
seq_id=seq_id))
|
||||
|
||||
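# Illustrative sketch (assumed numbers): requesting num_empty_slots = 10
# while _num_empty_slots = 3 and block_size = 16 gives
# slots_to_allocate = 7 and cdiv(7, 16) = 1, so one extra mutable GPU
# block is appended to the table.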
def fork(self) -> "BlockTable":
|
||||
"""Creates a new BlockTable instance with a copy of the blocks from the
|
||||
current instance.
|
||||
|
||||
This method creates a new BlockTable instance with the same block size,
|
||||
block allocator, and a copy of the blocks from the current instance. The
|
||||
new BlockTable has its own independent set of blocks, but shares the
|
||||
same underlying memory allocation with the original BlockTable.
|
||||
|
||||
Returns:
|
||||
BlockTable: A new BlockTable instance with a copy of the blocks from
|
||||
the current instance.
|
||||
"""
|
||||
assert self._is_allocated
|
||||
assert len(self._blocks) > 0
|
||||
forked_blocks = self._allocator.fork(self._blocks[-1])
|
||||
return BlockTable(
|
||||
block_size=self._block_size,
|
||||
block_allocator=self._allocator,
|
||||
_blocks=forked_blocks,
|
||||
max_block_sliding_window=self._max_block_sliding_window,
|
||||
)
|
||||
|
||||
def free(self, seq_id: Optional[int] = None) -> None:
|
||||
"""Frees the memory occupied by the blocks in the BlockTable.
|
||||
|
||||
This method iterates over all the blocks in the `_blocks` list and calls
|
||||
the `free` method of the `_allocator` object to release the memory
|
||||
occupied by each block. After freeing all the blocks, the `_blocks` list
|
||||
is reset.
|
||||
"""
|
||||
# Free the blocks in reverse allocation order without mutating the
# underlying block list.
|
||||
for block in reversed(self.blocks):
|
||||
self._allocator.free(block, seq_id=seq_id)
|
||||
self._blocks.reset()
|
||||
|
||||
@property
|
||||
def physical_block_ids(self) -> List[int]:
|
||||
"""Returns a list of physical block indices for the blocks in the
|
||||
BlockTable.
|
||||
|
||||
This property returns a list of integers, where each integer represents
|
||||
the physical block index of a corresponding block in the `_blocks` list.
|
||||
The physical block index is a unique identifier for the memory location
|
||||
occupied by the block.
|
||||
|
||||
Returns:
|
||||
List[int]: A list of physical block indices for the blocks in the
|
||||
BlockTable.
|
||||
"""
|
||||
return self._blocks.ids()
|
||||
|
||||
def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
|
||||
"""Get the number of "unseen" tokens in the sequence.
|
||||
|
||||
Unseen tokens are tokens in the sequence corresponding to this block
|
||||
table, but are not yet appended to this block table.
|
||||
|
||||
Args:
|
||||
sequence_token_ids (List[int]): The list of token ids in the
|
||||
sequence.
|
||||
|
||||
Returns:
|
||||
List[int]: The postfix of sequence_token_ids that has not yet been
|
||||
appended to the block table.
|
||||
"""
|
||||
|
||||
# Since the block table is append-only, the unseen token ids are the
|
||||
# ones after the appended ones.
|
||||
return sequence_token_ids[self.num_full_slots:]
|
||||
|
||||
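# Illustrative sketch (hypothetical values): with num_full_slots = 3 and
# sequence_token_ids = [10, 11, 12, 13, 14], the unseen suffix returned
# here is [13, 14].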
def _allocate_blocks_for_token_ids(
|
||||
self,
|
||||
prev_block: Optional[Block],
|
||||
token_ids: List[int],
|
||||
device: Device,
|
||||
extra_hash: Optional[int] = None,
|
||||
seq_id: Optional[int] = None) -> List[Block]:
|
||||
blocks: List[Block] = []
|
||||
|
||||
block_token_ids = []
|
||||
tail_token_ids = []
|
||||
for cur_token_ids in chunk_list(token_ids, self._block_size):
|
||||
if len(cur_token_ids) == self._block_size:
|
||||
block_token_ids.append(cur_token_ids)
|
||||
else:
|
||||
tail_token_ids.append(cur_token_ids)
|
||||
|
||||
if block_token_ids:
|
||||
blocks.extend(
|
||||
self._allocator.allocate_immutable_blocks(
|
||||
prev_block,
|
||||
block_token_ids=block_token_ids,
|
||||
device=device,
|
||||
extra_hash=extra_hash,
|
||||
seq_id=seq_id))
|
||||
prev_block = blocks[-1]
|
||||
|
||||
if tail_token_ids:
|
||||
assert len(tail_token_ids) == 1
|
||||
cur_token_ids = tail_token_ids[0]
|
||||
|
||||
block = self._allocator.allocate_mutable_block(
|
||||
prev_block=prev_block, device=device, extra_hash=extra_hash, seq_id=seq_id)
|
||||
block.append_token_ids(cur_token_ids, seq_id)
|
||||
|
||||
blocks.append(block)
|
||||
|
||||
return blocks
|
||||
|
||||
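# Illustrative sketch (assumed values): with block_size = 4 and ten token
# ids t0..t9, chunk_list yields [t0..t3], [t4..t7], [t8, t9]; the two full
# chunks become immutable blocks and the 2-token tail becomes a single
# mutable block, so three blocks are returned in sequence order.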
def _get_all_token_ids(self) -> List[int]:
|
||||
# NOTE: This function is O(seq_len); use sparingly.
|
||||
token_ids: List[int] = []
|
||||
|
||||
if not self._is_allocated:
|
||||
return token_ids
|
||||
|
||||
for block in self.blocks:
|
||||
token_ids.extend(block.token_ids)
|
||||
|
||||
return token_ids
|
||||
|
||||
def _get_num_token_ids(self) -> int:
|
||||
res = 0
|
||||
for block in self.blocks:
|
||||
res += len(block.token_ids)
|
||||
|
||||
return res
|
||||
|
||||
@property
|
||||
def _is_allocated(self) -> bool:
|
||||
return len(self._blocks) > 0
|
||||
|
||||
@property
|
||||
def blocks(self) -> List[Block]:
|
||||
return self._blocks.list()
|
||||
|
||||
@property
|
||||
def _num_empty_slots(self) -> int:
|
||||
assert self._is_allocated
|
||||
return len(self._blocks) * self._block_size - self._num_full_slots
|
||||
|
||||
@property
|
||||
def num_full_slots(self) -> int:
|
||||
"""Returns the total number of tokens currently stored in the
|
||||
BlockTable.
|
||||
|
||||
Returns:
|
||||
int: The total number of tokens currently stored in the BlockTable.
|
||||
"""
|
||||
return self._num_full_slots
|
||||
|
||||
def get_num_blocks_touched_by_append_slots(
|
||||
self, token_ids: List[int], num_lookahead_slots: int) -> int:
|
||||
"""Determine how many blocks will be "touched" by appending the token
|
||||
ids.
|
||||
|
||||
This is required for the scheduler to determine whether a sequence can
|
||||
continue generation, or if it must be preempted.
|
||||
"""
|
||||
# Math below is equivalent to:
|
||||
# all_token_ids = token_ids + [-1] * num_lookahead_slots
|
||||
# token_blocks = self._chunk_token_blocks_for_append(all_token_ids)
|
||||
# return len(token_blocks)
|
||||
|
||||
num_token_ids = len(token_ids) + num_lookahead_slots
|
||||
first_chunk_size = self._block_size - (self._num_full_slots %
|
||||
self._block_size)
|
||||
num_token_blocks = (1 + math.ceil(
|
||||
(num_token_ids - first_chunk_size) / self._block_size))
|
||||
return num_token_blocks
|
||||
|
||||
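# Worked example (assumed numbers): block_size = 16, _num_full_slots = 20,
# len(token_ids) = 10, num_lookahead_slots = 5. Then
# first_chunk_size = 16 - (20 % 16) = 12, num_token_ids = 15, and
# 1 + ceil((15 - 12) / 16) = 2 blocks are touched.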
def _chunk_token_blocks_for_append(
|
||||
self, token_ids: List[int]) -> List[List[int]]:
|
||||
"""Split the token ids into block-sized chunks so they can be easily
|
||||
appended to blocks. The first such "token block" may have less token ids
|
||||
than the block size, since the last allocated block may be partially
|
||||
full.
|
||||
|
||||
If no token ids are provided, then no chunks are returned.
|
||||
"""
|
||||
|
||||
if not token_ids:
|
||||
return []
|
||||
|
||||
first_chunk_size = self._block_size - (self._num_full_slots %
|
||||
self._block_size)
|
||||
token_blocks = [token_ids[:first_chunk_size]]
|
||||
token_blocks.extend(
|
||||
chunk_list(token_ids[first_chunk_size:], self._block_size))
|
||||
return token_blocks
|
||||
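# Illustrative sketch (assumed values): with block_size = 4 and
# _num_full_slots = 7, the last block has one free slot, so
# _chunk_token_blocks_for_append([a, b, c, d, e, f]) returns
# [[a], [b, c, d, e], [f]].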
21
vllm_vacc/vllm/core/block/common.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple
|
||||
|
||||
from vllm.core.block.interfaces import Block, BlockAllocator
|
||||
|
||||
BlockId = int
|
||||
RefCount = int
|
||||
|
||||
class BlockList:
|
||||
def append_token_ids(self, block_index: int, token_ids: List[int], seq_id: Optional[int]=None) -> None:
|
||||
block = self._blocks[block_index]
|
||||
prev_block_id = block.block_id
|
||||
|
||||
block.append_token_ids(token_ids, seq_id=seq_id)
|
||||
|
||||
# CoW or promotion may update the internal block_id
|
||||
if prev_block_id != block.block_id:
|
||||
self._update_block_id(block_index, block.block_id)
|
||||
373
vllm_vacc/vllm/core/block/cpu_gpu_block_allocator.py
Normal file
@@ -0,0 +1,373 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Dict, FrozenSet, List, Optional, Tuple
|
||||
|
||||
from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId,
|
||||
DeviceAwareBlockAllocator)
|
||||
# from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
|
||||
from vllm.core.block.naive_block import NaiveBlock
|
||||
from vllm_vacc.vllm.core.block.naive_block import NaiveBlockAllocator
|
||||
from vllm_vacc.vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import Device
|
||||
|
||||
from vllm.core.block.cpu_gpu_block_allocator import NullBlock
|
||||
|
||||
class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
|
||||
"""A block allocator that can allocate blocks on both CPU and GPU memory.
|
||||
|
||||
This class implements the `DeviceAwareBlockAllocator` interface and provides
|
||||
functionality for allocating and managing blocks of memory on both CPU and
|
||||
GPU devices.
|
||||
|
||||
The `CpuGpuBlockAllocator` maintains separate memory pools for CPU and GPU
|
||||
blocks, and allows for allocation, deallocation, forking, and swapping of
|
||||
blocks across these memory pools.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
allocator_type: str,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
block_size: int,
|
||||
) -> DeviceAwareBlockAllocator:
|
||||
"""Creates a CpuGpuBlockAllocator instance with the specified
|
||||
configuration.
|
||||
|
||||
This static method creates and returns a CpuGpuBlockAllocator instance
|
||||
based on the provided parameters. It initializes the CPU and GPU block
|
||||
allocators with the specified number of blocks, block size, and
|
||||
allocator type.
|
||||
|
||||
Args:
|
||||
allocator_type (str): The type of block allocator to use for CPU
|
||||
and GPU blocks. Currently supported values are "naive" and
|
||||
"prefix_caching".
|
||||
num_gpu_blocks (int): The number of blocks to allocate for GPU
|
||||
memory.
|
||||
num_cpu_blocks (int): The number of blocks to allocate for CPU
|
||||
memory.
|
||||
block_size (int): The size of each block in number of tokens.
|
||||
|
||||
Returns:
|
||||
DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with the
|
||||
specified configuration.
|
||||
|
||||
Notes:
|
||||
- The block IDs are assigned contiguously, with GPU block IDs coming
|
||||
before CPU block IDs.
|
||||
"""
|
||||
# For HPU, block id 0 is used only for padding
|
||||
reserved_blocks = 1 if current_platform.is_hpu() else 0
|
||||
block_ids = list(
|
||||
range(reserved_blocks, num_gpu_blocks + num_cpu_blocks))
|
||||
num_gpu_blocks -= reserved_blocks
|
||||
gpu_block_ids = block_ids[:num_gpu_blocks]
|
||||
cpu_block_ids = block_ids[num_gpu_blocks:]
|
||||
|
||||
if allocator_type == "naive":
|
||||
gpu_allocator: BlockAllocator = NaiveBlockAllocator(
|
||||
create_block=NaiveBlock, # type: ignore
|
||||
num_blocks=num_gpu_blocks,
|
||||
block_size=block_size,
|
||||
block_ids=gpu_block_ids,
|
||||
)
|
||||
|
||||
cpu_allocator: BlockAllocator = NaiveBlockAllocator(
|
||||
create_block=NaiveBlock, # type: ignore
|
||||
num_blocks=num_cpu_blocks,
|
||||
block_size=block_size,
|
||||
block_ids=cpu_block_ids,
|
||||
)
|
||||
elif allocator_type == "prefix_caching":
|
||||
gpu_allocator = PrefixCachingBlockAllocator(
|
||||
num_blocks=num_gpu_blocks,
|
||||
block_size=block_size,
|
||||
block_ids=gpu_block_ids,
|
||||
)
|
||||
|
||||
cpu_allocator = PrefixCachingBlockAllocator(
|
||||
num_blocks=num_cpu_blocks,
|
||||
block_size=block_size,
|
||||
block_ids=cpu_block_ids,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown allocator type {allocator_type=}")
|
||||
|
||||
return CpuGpuBlockAllocator(
|
||||
cpu_block_allocator=cpu_allocator,
|
||||
gpu_block_allocator=gpu_allocator,
|
||||
)
|
||||
|
||||
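# Minimal usage sketch (parameter values are assumptions for illustration):
#
#   allocator = CpuGpuBlockAllocator.create(allocator_type="naive",
#                                           num_gpu_blocks=4,
#                                           num_cpu_blocks=2,
#                                           block_size=16)
#
# On a non-HPU platform this assigns GPU block ids [0, 1, 2, 3] and CPU
# block ids [4, 5], matching the contiguous-id note above.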
def __init__(self, cpu_block_allocator: BlockAllocator,
|
||||
gpu_block_allocator: BlockAllocator):
|
||||
assert not (
|
||||
cpu_block_allocator.all_block_ids
|
||||
& gpu_block_allocator.all_block_ids
|
||||
), "cpu and gpu block allocators can't have intersection of block ids"
|
||||
|
||||
self._allocators = {
|
||||
Device.CPU: cpu_block_allocator,
|
||||
Device.GPU: gpu_block_allocator,
|
||||
}
|
||||
|
||||
self._swap_mapping: Dict[int, int] = {}
|
||||
self._null_block: Optional[Block] = None
|
||||
|
||||
self._block_ids_to_allocator: Dict[int, BlockAllocator] = {}
|
||||
for _, allocator in self._allocators.items():
|
||||
for block_id in allocator.all_block_ids:
|
||||
self._block_ids_to_allocator[block_id] = allocator
|
||||
|
||||
def allocate_or_get_null_block(self) -> Block:
|
||||
if self._null_block is None:
|
||||
self._null_block = NullBlock(
|
||||
self.allocate_mutable_block(None, Device.GPU))
|
||||
return self._null_block
|
||||
|
||||
def allocate_mutable_block(self,
|
||||
prev_block: Optional[Block],
|
||||
device: Device,
|
||||
extra_hash: Optional[int] = None,
|
||||
seq_id: Optional[int] = None) -> Block:
|
||||
"""Allocates a new mutable block on the specified device.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[Block]): The previous block in the sequence.
|
||||
Used for prefix hashing.
|
||||
device (Device): The device on which to allocate the new block.
|
||||
extra_hash (Optional[int]): The hash value of additional
|
||||
factors, such as adapters, that influence the block hash
|
||||
in the prefix caching block.
|
||||
|
||||
Returns:
|
||||
Block: The newly allocated mutable block.
|
||||
"""
|
||||
return self._allocators[device].allocate_mutable_block(
|
||||
prev_block, extra_hash=extra_hash, seq_id=seq_id)
|
||||
|
||||
def allocate_immutable_blocks(
|
||||
self,
|
||||
prev_block: Optional[Block],
|
||||
block_token_ids: List[List[int]],
|
||||
device: Device,
|
||||
extra_hash: Optional[int] = None,
|
||||
seq_id: Optional[int] = None) -> List[Block]:
|
||||
"""Allocates a new group of immutable blocks with the provided block
|
||||
token IDs on the specified device.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[Block]): The previous block in the sequence.
|
||||
Used for prefix hashing.
|
||||
block_token_ids (List[List[int]]): The lists of token IDs to be
|
||||
stored in the new blocks.
|
||||
device (Device): The device on which to allocate the new block.
|
||||
extra_hash (Optional[int]): The hash value of additional
|
||||
factors, such as adapters, that influence the block hash
|
||||
in the prefix caching block.
|
||||
|
||||
Returns:
|
||||
List[Block]: The newly allocated list of immutable blocks
|
||||
containing the provided block token IDs.
|
||||
"""
|
||||
return self._allocators[device].allocate_immutable_blocks(
|
||||
prev_block, block_token_ids, extra_hash=extra_hash, seq_id=seq_id)
|
||||
|
||||
def allocate_immutable_block(self,
|
||||
prev_block: Optional[Block],
|
||||
token_ids: List[int],
|
||||
device: Device,
|
||||
extra_hash: Optional[int] = None) -> Block:
|
||||
"""Allocates a new immutable block with the provided token IDs on the
|
||||
specified device.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[Block]): The previous block in the sequence.
|
||||
Used for prefix hashing.
|
||||
token_ids (List[int]): The list of token IDs to be stored in the new
|
||||
block.
|
||||
device (Device): The device on which to allocate the new block.
|
||||
extra_hash (Optional[int]): The hash value of additional
|
||||
factors, such as adapters, that influence the block hash
|
||||
in the prefix caching block.
|
||||
|
||||
Returns:
|
||||
Block: The newly allocated immutable block containing the provided
|
||||
token IDs.
|
||||
"""
|
||||
return self._allocators[device].allocate_immutable_block(
|
||||
prev_block, token_ids, extra_hash=extra_hash)
|
||||
|
||||
def free(self, block: Block, seq_id: Optional[int] = None) -> None:
|
||||
"""Frees the memory occupied by the given block.
|
||||
|
||||
Args:
|
||||
block (Block): The block to be freed.
|
||||
"""
|
||||
# Null block should never be freed
|
||||
if isinstance(block, NullBlock):
|
||||
return
|
||||
block_id = block.block_id
|
||||
assert block_id is not None
|
||||
allocator = self._block_ids_to_allocator[block_id]
|
||||
allocator.free(block, seq_id=seq_id)
|
||||
|
||||
def fork(self, last_block: Block) -> List[Block]:
|
||||
"""Creates a new sequence of blocks that shares the same underlying
|
||||
memory as the original sequence.
|
||||
|
||||
Args:
|
||||
last_block (Block): The last block in the original sequence.
|
||||
|
||||
Returns:
|
||||
List[Block]: A new list of blocks that shares the same memory as the
|
||||
original sequence.
|
||||
"""
|
||||
# do not attempt to fork the null block
|
||||
assert not isinstance(last_block, NullBlock)
|
||||
block_id = last_block.block_id
|
||||
assert block_id is not None
|
||||
allocator = self._block_ids_to_allocator[block_id]
|
||||
return allocator.fork(last_block)
|
||||
|
||||
def get_num_free_blocks(self, device: Device, seq_id: Optional[int] = None) -> int:
|
||||
"""Returns the number of free blocks available on the specified device.
|
||||
|
||||
Args:
|
||||
device (Device): The device for which to query the number of free
|
||||
blocks. AssertionError is raised if None is passed.
|
||||
|
||||
Returns:
|
||||
int: The number of free blocks available on the specified device.
|
||||
"""
|
||||
return self._allocators[device].get_num_free_blocks(seq_id=seq_id)
|
||||
|
||||
def get_num_total_blocks(self, device: Device, seq_id: Optional[int] = None) -> int:
|
||||
return self._allocators[device].get_num_total_blocks(seq_id=seq_id)
|
||||
|
||||
def get_physical_block_id(self, device: Device, absolute_id: int) -> int:
|
||||
"""Returns the zero-offset block id on certain device given the
|
||||
absolute block id.
|
||||
|
||||
Args:
|
||||
device (Device): The device for which to query relative block id.
|
||||
absolute_id (int): The absolute block id for the block in
|
||||
whole allocator.
|
||||
|
||||
Returns:
|
||||
int: The zero-offset block id on certain device.
|
||||
"""
|
||||
return self._allocators[device].get_physical_block_id(absolute_id)
|
||||
|
||||
def swap(self, blocks: List[Block], src_device: Device,
|
||||
dst_device: Device) -> Dict[int, int]:
|
||||
"""Execute the swap for the given blocks from source_device
|
||||
on to dest_device, save the current swap mapping and append
|
||||
them to the accumulated `self._swap_mapping` for each
|
||||
scheduling move.
|
||||
|
||||
Args:
|
||||
blocks: List of blocks to be swapped.
|
||||
src_device (Device): Device to swap the 'blocks' from.
|
||||
dst_device (Device): Device to swap the 'blocks' to.
|
||||
|
||||
Returns:
|
||||
Dict[int, int]: Swap mapping from source_device
|
||||
on to dest_device.
|
||||
"""
|
||||
src_block_ids = [block.block_id for block in blocks]
|
||||
self._allocators[src_device].swap_out(blocks)
|
||||
self._allocators[dst_device].swap_in(blocks)
|
||||
dst_block_ids = [block.block_id for block in blocks]
|
||||
|
||||
current_swap_mapping: Dict[int, int] = {}
|
||||
for src_block_id, dst_block_id in zip(src_block_ids, dst_block_ids):
|
||||
if src_block_id is not None and dst_block_id is not None:
|
||||
self._swap_mapping[src_block_id] = dst_block_id
|
||||
current_swap_mapping[src_block_id] = dst_block_id
|
||||
return current_swap_mapping
|
||||
|
||||
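# Illustrative bookkeeping sketch (block ids invented for the example): if
# a sequence's blocks held GPU ids [3, 7] and swap(..., Device.GPU,
# Device.CPU) reassigns them to CPU ids [101, 102], the method returns
# {3: 101, 7: 102} and the same pairs stay in self._swap_mapping until
# get_and_reset_swaps() drains them.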
def get_num_full_blocks_touched(self, blocks: List[Block],
|
||||
device: Device) -> int:
|
||||
"""Returns the number of full blocks that will be touched by
|
||||
swapping in/out the given blocks on to the 'device'.
|
||||
|
||||
Args:
|
||||
blocks: List of blocks to be swapped.
|
||||
device (Device): Device to swap the 'blocks' on.
|
||||
|
||||
Returns:
|
||||
int: the number of full blocks that will be touched by
|
||||
swapping in/out the given blocks on to the 'device'.
|
||||
Non full blocks are ignored when deciding the number
|
||||
of blocks to touch.
|
||||
"""
|
||||
return self._allocators[device].get_num_full_blocks_touched(blocks)
|
||||
|
||||
def clear_copy_on_writes(self, seq_id: Optional[int] = None) -> List[Tuple[int, int]]:
|
||||
"""Clears the copy-on-write (CoW) state and returns the mapping of
|
||||
source to destination block IDs.
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, int]]: A list mapping source block IDs to
|
||||
destination block IDs.
|
||||
"""
|
||||
# CoW only supported on GPU
|
||||
device = Device.GPU
|
||||
return self._allocators[device].clear_copy_on_writes()
|
||||
|
||||
def mark_blocks_as_accessed(self, block_ids: List[int],
|
||||
now: float) -> None:
|
||||
"""Mark blocks as accessed, only use for prefix caching."""
|
||||
# Prefix caching only supported on GPU.
|
||||
device = Device.GPU
|
||||
return self._allocators[device].mark_blocks_as_accessed(block_ids, now)
|
||||
|
||||
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
|
||||
"""Mark blocks as accessed, only use for prefix caching."""
|
||||
# Prefix caching only supported on GPU.
|
||||
device = Device.GPU
|
||||
return self._allocators[device].mark_blocks_as_computed(block_ids)
|
||||
|
||||
def get_common_computed_block_ids(
|
||||
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
|
||||
# Prefix caching only supported on GPU.
|
||||
device = Device.GPU
|
||||
return self._allocators[device].get_common_computed_block_ids(
|
||||
computed_seq_block_ids)
|
||||
|
||||
@property
|
||||
def all_block_ids(self) -> FrozenSet[int]:
|
||||
return frozenset(self._block_ids_to_allocator.keys())
|
||||
|
||||
def get_prefix_cache_hit_rate(self, device: Device) -> float:
|
||||
"""Prefix cache hit rate. -1 means not supported or disabled."""
|
||||
assert device in self._allocators
|
||||
return self._allocators[device].get_prefix_cache_hit_rate()
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
"""Reset prefix cache for all devices."""
|
||||
success = True
|
||||
for allocator in self._allocators.values():
|
||||
success = success and allocator.reset_prefix_cache()
|
||||
return success
|
||||
|
||||
def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
|
||||
"""Returns and clears the mapping of source to destination block IDs.
|
||||
Will be called after every swapping operation for now, and after every
|
||||
schedule once BlockManagerV2 becomes the default. Currently not useful.
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, int]]: A mapping of source to destination block IDs.
|
||||
"""
|
||||
mapping = self._swap_mapping.copy()
|
||||
self._swap_mapping.clear()
|
||||
return list(mapping.items())
|
||||
|
||||
def find_cached_blocks_prefix(
|
||||
self,
|
||||
block_hashes: List[int],
|
||||
device: Device = Device.GPU,
|
||||
) -> List[int]:
|
||||
return self._allocators[device].find_cached_blocks_prefix(block_hashes)
|
||||
464
vllm_vacc/vllm/core/block/naive_block.py
Normal file
@@ -0,0 +1,464 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from collections import deque
|
||||
from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
|
||||
get_all_blocks_recursively)
|
||||
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
|
||||
|
||||
from typing import Dict
|
||||
from vllm.logger import init_logger
|
||||
import os
|
||||
|
||||
max_seq_num = int(os.getenv("MAX_SEQ_NUM", 4))
|
||||
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
||||
|
||||
Refcount = int
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
class NaiveBlockAllocator(BlockAllocator):
|
||||
"""A simple block allocator that manages blocks of memory without prefix
|
||||
caching.
|
||||
|
||||
Args:
|
||||
create_block (Block.Factory): A factory function for creating new
|
||||
blocks. This is used when a NaiveBlockAllocator is composed within
|
||||
a prefix caching allocator -- the naive block allocator must
|
||||
construct prefix caching blocks (but shouldn't know anything else
|
||||
about them).
|
||||
num_blocks (int): The total number of blocks to manage.
|
||||
block_size (int): The size of each block in tokens.
|
||||
block_ids (Optional[Iterable[int]], optional): An optional iterable of
|
||||
block IDs. If not provided, block IDs will be assigned sequentially
|
||||
from 0 to num_blocks - 1.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
create_block: Block.Factory,
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
block_ids: Optional[Iterable[int]] = None,
|
||||
block_pool: Optional[BlockPool] = None,
|
||||
):
|
||||
# mapping from seq_id to its list of block group ids
|
||||
self.is_partitioned = False
|
||||
self.num_blocks = num_blocks
|
||||
self.seq_mapping: Dict[int, List[int]] = {}
|
||||
|
||||
if block_ids is None:
|
||||
block_ids = range(num_blocks)
|
||||
|
||||
self._free_block_indices_all: Deque[BlockId] = deque(block_ids)
|
||||
self._all_block_indices = frozenset(block_ids)
|
||||
assert len(self._all_block_indices) == num_blocks
|
||||
|
||||
self._refcounter = RefCounter(
|
||||
all_block_indices=self._free_block_indices_all)
|
||||
self._block_size = block_size
|
||||
|
||||
self._cow_tracker = CopyOnWriteTracker(
|
||||
refcounter=self._refcounter.as_readonly())
|
||||
|
||||
if block_pool is None:
|
||||
extra_factor = 4
|
||||
# Pre-allocate "num_blocks * extra_factor" block objects.
|
||||
# The "* extra_factor" is a buffer to allow more block objects
|
||||
# than physical blocks
|
||||
self._block_pool = BlockPool(self._block_size, create_block, self,
|
||||
num_blocks * extra_factor)
|
||||
else:
|
||||
# In this case, the block pool is provided by the caller,
|
||||
# which means that there is most likely a need to share
|
||||
# a block pool between allocators
|
||||
self._block_pool = block_pool
|
||||
|
||||
# partition blocks to block groups
|
||||
self.partition_blocks()
|
||||
|
||||
def allocate_immutable_block(self,
|
||||
prev_block: Optional[Block],
|
||||
token_ids: List[int],
|
||||
extra_hash: Optional[int] = None,
|
||||
device: Optional[Device] = None) -> Block:
|
||||
"""Allocates a new immutable block with the given token IDs, linked to
|
||||
the previous block.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[Block]): The previous block in the sequence. If
|
||||
None, then the block to be allocated is the first block in the
|
||||
sequence.
|
||||
token_ids (List[int]): The token IDs to be stored in the new block.
|
||||
|
||||
Returns:
|
||||
Block: The newly allocated immutable block.
|
||||
"""
|
||||
assert device is None
|
||||
block = self.allocate_mutable_block(prev_block=prev_block)
|
||||
block.append_token_ids(token_ids)
|
||||
return block
|
||||
|
||||
def allocate_immutable_blocks(
|
||||
self,
|
||||
prev_block: Optional[Block],
|
||||
block_token_ids: List[List[int]],
|
||||
extra_hash: Optional[int] = None,
|
||||
device: Optional[Device] = None,
|
||||
seq_id: Optional[int] = None) -> List[Block]:
|
||||
assert device is None
|
||||
num_blocks = len(block_token_ids)
|
||||
|
||||
block_ids = []
|
||||
for i in range(num_blocks):
|
||||
block_ids.append(self._allocate_block_id(seq_id=seq_id))
|
||||
|
||||
blocks = []
|
||||
for i in range(num_blocks):
|
||||
prev_block = self._block_pool.init_block(
|
||||
prev_block=prev_block,
|
||||
token_ids=block_token_ids[i],
|
||||
block_size=self._block_size,
|
||||
physical_block_id=block_ids[i])
|
||||
blocks.append(prev_block)
|
||||
|
||||
return blocks
|
||||
|
||||
def allocate_mutable_block(self,
|
||||
prev_block: Optional[Block],
|
||||
extra_hash: Optional[int] = None,
|
||||
device: Optional[Device] = None,
|
||||
seq_id: Optional[int] = None) -> Block:
|
||||
"""Allocates a new mutable block, linked to the previous block.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[Block]): The previous block in the sequence. If
|
||||
None, then the block to be allocated is the first block in the
|
||||
sequence.
|
||||
|
||||
Returns:
|
||||
Block: The newly allocated mutable block.
|
||||
"""
|
||||
assert device is None
|
||||
assert seq_id is not None
|
||||
block_id = self._allocate_block_id(seq_id)
|
||||
block = self._block_pool.init_block(prev_block=prev_block,
|
||||
token_ids=[],
|
||||
block_size=self._block_size,
|
||||
physical_block_id=block_id)
|
||||
return block
|
||||
|
||||
def _allocate_block_id(self, seq_id: Optional[int] = None) -> BlockId:
|
||||
assert seq_id is not None
|
||||
# always use the latest grp id to allocate
|
||||
# since the previous grps are exhausted
|
||||
grp_id_list = self.get_group_list(seq_id)
|
||||
if len(grp_id_list) == 0 or len(self._free_block_indices[grp_id_list[-1]]) == 0:
|
||||
if not self._free_block_grp_indices:
|
||||
# no more block id in block group pool
|
||||
# should not reach here
|
||||
raise BlockAllocator.NoFreeBlocksError()
|
||||
|
||||
else:
|
||||
# pop a new block and add to seq_mapping
|
||||
grp_id = self._free_block_grp_indices.popleft()
|
||||
grp_id_list.append(grp_id)
|
||||
self.seq_mapping[seq_id] = grp_id_list
|
||||
grp_id = grp_id_list[-1]
|
||||
block_id = self._free_block_indices[grp_id].popleft()
|
||||
self._refcounter.incr(block_id)
|
||||
return block_id
|
||||
|
||||
def _free_block_id(self, block: Union[Block, BlockId], seq_id: Optional[int] = None) -> None:
|
||||
assert seq_id is not None
|
||||
grp_id_list = self.get_group_list(seq_id)
|
||||
if isinstance(block, Block):
|
||||
block_id = block.block_id
|
||||
block.block_id = None
|
||||
else:
|
||||
block_id = block
|
||||
assert block_id is not None
|
||||
|
||||
# block_id should always be in the newest group, grp_id_list[-1],
|
||||
# since blocks are freed in reverse allocation order
|
||||
grp_id = grp_id_list[-1]
|
||||
assert block_id in self._block_grp_indices[grp_id], f"grp_id: {grp_id} block_id:{block_id}"
|
||||
|
||||
refcount = self._refcounter.decr(block_id)
|
||||
if refcount == 0:
|
||||
self._free_block_indices[grp_id].appendleft(block_id)
|
||||
if len(self._free_block_indices[grp_id]) == len(self._block_grp_indices[grp_id]):
|
||||
# free group
|
||||
self.seq_mapping[seq_id].remove(grp_id)
|
||||
if len(self.seq_mapping[seq_id]) == 0:
|
||||
# free seq_id
|
||||
del self.seq_mapping[seq_id]
|
||||
# collect back to block group pool
|
||||
self._free_block_grp_indices.appendleft(grp_id)
|
||||
|
||||
def free(self, block: Block, keep_block_object: bool = False, seq_id: Optional[int] = None) -> None:
|
||||
# Release the physical block id
|
||||
self._free_block_id(block, seq_id=seq_id)
|
||||
|
||||
# Release the block object
|
||||
if not keep_block_object:
|
||||
self._block_pool.free_block(block)
|
||||
|
||||
def free_block_id(self, block_id: BlockId, seq_id: Optional[int] = None) -> None:
|
||||
assert seq_id is not None
|
||||
self._free_block_id(block_id, seq_id)
|
||||
|
||||
def fork(self, last_block: Block, seq_id: Optional[int] = None) -> List[Block]:
|
||||
"""Creates a new sequence of blocks that shares the same underlying
|
||||
memory as the original sequence.
|
||||
|
||||
Args:
|
||||
last_block (Block): The last block in the original sequence.
|
||||
|
||||
Returns:
|
||||
List[Block]: The new sequence of blocks that shares the same memory
|
||||
as the original sequence.
|
||||
"""
|
||||
source_blocks = get_all_blocks_recursively(last_block)
|
||||
|
||||
forked_blocks: List[Block] = []
|
||||
prev_block = None
|
||||
grp_id_list = self.get_group_list(seq_id)
|
||||
for block in source_blocks:
|
||||
# Increment refcount for each block.
|
||||
assert block.block_id is not None
|
||||
grp_id = self.get_group_id(block.block_id, grp_id_list)
|
||||
assert grp_id != -1, "can't locate block group"
|
||||
refcount = self._refcounter.incr(block.block_id)
|
||||
assert refcount != 1, "can't fork free'd block"
|
||||
|
||||
forked_block = self._block_pool.init_block(
|
||||
prev_block=prev_block,
|
||||
token_ids=block.token_ids,
|
||||
block_size=self._block_size,
|
||||
physical_block_id=block.block_id)
|
||||
|
||||
forked_blocks.append(forked_block)
|
||||
prev_block = forked_blocks[-1]
|
||||
|
||||
return forked_blocks
|
||||
|
||||
def partition_blocks(self) -> None:
|
||||
# only partition once in each vllm server lifecycle
|
||||
if self.is_partitioned: #and len(self.seq_mapping) > 0:
|
||||
return
|
||||
|
||||
self.is_partitioned = True
|
||||
self._blk_num_per_grp = env_blk_grp_size // self._block_size
|
||||
self._all_blk_grp_num = self.num_blocks // self._blk_num_per_grp
|
||||
|
||||
block_groups = []
|
||||
for i in range(self._all_blk_grp_num):
|
||||
start = i * self._blk_num_per_grp
|
||||
block_groups.append([k for k in range(start, start + self._blk_num_per_grp)])
|
||||
|
||||
self._free_block_grp_indices: Deque[BlockId] = deque(range(self._all_blk_grp_num))
|
||||
# self._free_block_grp_indices: Deque[BlockId] = deque(range(self._all_blk_grp_num-1,-1,-1))
|
||||
self._free_block_indices: List[Deque[BlockId]] = [deque(block_ids) for block_ids in block_groups]
|
||||
self._block_grp_indices: List[FrozenSet] = [frozenset(block_ids) for block_ids in block_groups]
|
||||
# self._all_block_indices =[frozenset(block_ids) for block_ids in block_groups]
|
||||
|
||||
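# Worked sketch under assumed settings: with BLOCK_GROUP_SIZE = 8192 tokens
# and block_size = 16, _blk_num_per_grp = 512; 2048 physical blocks then
# form _all_blk_grp_num = 4 groups covering ids [0..511], [512..1023],
# [1024..1535] and [1536..2047]. (Numbers are illustrative only.)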
# get the group id of a given block_id
|
||||
def get_group_id(self, block_id, grp_id_list) -> int:
|
||||
for i in grp_id_list:
|
||||
if block_id in self._block_grp_indices[i]:
|
||||
return i
|
||||
assert False
|
||||
# return -1
|
||||
|
||||
# get group id list according to seq_id
|
||||
def get_group_list(self, seq_id) -> List[int]:
|
||||
"""Get group id list acoording to current seq_id
|
||||
key: seq_id, value: [grp_id, grp_id, ...]
|
||||
"""
|
||||
assert seq_id is not None
|
||||
grp_id_list = []
|
||||
if seq_id in self.seq_mapping:
|
||||
grp_id_list = self.seq_mapping[seq_id]
|
||||
return grp_id_list
|
||||
|
||||
def get_num_free_blocks(self, seq_id: Optional[int] = None) -> int:
|
||||
free_blocks = len(self._free_block_grp_indices) * self._blk_num_per_grp
|
||||
if seq_id is not None:
|
||||
if seq_id in self.seq_mapping:
|
||||
# seq_id is already allocated
|
||||
grp_id_list = self.seq_mapping[seq_id]
|
||||
free_blocks += len(self._free_block_indices[grp_id_list[-1]])
|
||||
else:
|
||||
# new seq_id
|
||||
if len(self.seq_mapping) >= max_seq_num:
|
||||
return 0
|
||||
# means a real allocation; only count fully free 8K block groups
|
||||
return free_blocks
|
||||
else:
|
||||
# for memory usage analysis / swap
|
||||
# also need to count the free blocks inside in-use 8K block groups
|
||||
for _, grp_id_list in self.seq_mapping.items():
|
||||
if len(grp_id_list) > 0:
|
||||
free_blocks += len(self._free_block_indices[grp_id_list[-1]])
|
||||
return free_blocks
|
||||
|
||||
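# Illustrative sketch (assumed numbers): with 512 blocks per group, one
# fully free group, and a running seq_id whose newest group still has 100
# free blocks, calling this with that seq_id returns 512 + 100 = 612; a
# brand-new seq_id (below MAX_SEQ_NUM active sequences) only sees the 512
# blocks of the fully free group.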
def get_num_total_blocks(self, seq_id: Optional[int] = None) -> int:
|
||||
return len(self._all_block_indices)
|
||||
|
||||
def get_physical_block_id(self, absolute_id: int) -> int:
|
||||
"""Returns the zero-offset block id on certain block allocator
|
||||
given the absolute block id.
|
||||
|
||||
Args:
|
||||
absolute_id (int): The absolute block id for the block
|
||||
in whole allocator.
|
||||
|
||||
Returns:
|
||||
int: The zero-offset block id on certain device.
|
||||
"""
|
||||
return sorted(self._all_block_indices).index(absolute_id)
|
||||
|
||||
@property
|
||||
def refcounter(self):
|
||||
return self._refcounter
|
||||
|
||||
@property
|
||||
def all_block_ids(self) -> FrozenSet[int]:
|
||||
return self._all_block_indices
|
||||
|
||||
def cow_block_if_not_appendable(self, block: Block, seq_id: Optional[int] = None) -> BlockId:
|
||||
"""Performs a copy-on-write operation on the given block if it is not
|
||||
appendable.
|
||||
|
||||
Args:
|
||||
block (Block): The block to check for copy-on-write.
|
||||
|
||||
Returns:
|
||||
BlockId: The block index of the new block if a copy-on-write
|
||||
operation was performed, or the original block index if
|
||||
no copy-on-write was necessary.
|
||||
"""
|
||||
src_block_id = block.block_id
|
||||
assert src_block_id is not None
|
||||
|
||||
if self._cow_tracker.is_appendable(block):
|
||||
return src_block_id
|
||||
|
||||
self._free_block_id(block, seq_id)
|
||||
trg_block_id = self._allocate_block_id()
|
||||
|
||||
self._cow_tracker.record_cow(src_block_id, trg_block_id)
|
||||
|
||||
return trg_block_id
|
||||
|
||||
def clear_copy_on_writes(self, seq_id: Optional[int] = None) -> List[Tuple[BlockId, BlockId]]:
|
||||
"""Returns the copy-on-write source->destination mapping and clears it.
|
||||
|
||||
Returns:
|
||||
List[Tuple[BlockId, BlockId]]: A list mapping source
|
||||
block indices to destination block indices.
|
||||
"""
|
||||
return self._cow_tracker.clear_cows()
|
||||
|
||||
def mark_blocks_as_accessed(self, block_ids: List[int],
|
||||
now: float) -> None:
|
||||
"""Mark blocks as accessed, used in prefix caching.
|
||||
|
||||
Since the naive allocator does not implement prefix caching, we do
|
||||
nothing.
|
||||
"""
|
||||
pass
|
||||
|
||||
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
|
||||
"""Mark blocks as computed, used in prefix caching.
|
||||
|
||||
Since the naive allocator does not implement prefix caching, we do
|
||||
nothing.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_common_computed_block_ids(
|
||||
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
|
||||
"""Determine blocks that can be skipped in prefill.
|
||||
|
||||
Since the naive allocator does not support prefix caching, always return
|
||||
an empty list.
|
||||
"""
|
||||
return []
|
||||
|
||||
def promote_to_immutable_block(self, block: Block) -> BlockId:
|
||||
raise NotImplementedError("There is no promotion for naive blocks")
|
||||
|
||||
def get_num_full_blocks_touched(self, blocks: List[Block]) -> int:
|
||||
"""Returns the number of full blocks that will be touched by
|
||||
swapping in/out.
|
||||
|
||||
Args:
|
||||
blocks: List of blocks to be swapped.
|
||||
Returns:
|
||||
int: the number of full blocks that will be touched by
|
||||
swapping in/out the given blocks. Non full blocks are ignored
|
||||
when deciding the number of blocks to touch.
|
||||
"""
|
||||
# NOTE: for naive block, we use set to eliminate common blocks among
|
||||
# seqs, also we compare the empty slots in the mutable blocks with
|
||||
# lookahead slots to get the number of unique new block that are
|
||||
# needed.
|
||||
old_block_set = set()
|
||||
for block in blocks:
|
||||
if block.is_full:
|
||||
old_block_set.add(block)
|
||||
return len(old_block_set)
|
||||
|
||||
def swap_out(self, blocks: List[Block], seq_id: Optional[int] = None) -> None:
|
||||
for block in blocks:
|
||||
self._free_block_id(block, seq_id)
|
||||
|
||||
def swap_in(self, blocks: List[Block]) -> None:
|
||||
for block in blocks:
|
||||
# Here we allocate either immutable or mutable block and then
|
||||
# extract its block_id. Note that the block object is released
|
||||
# and the block_id is assigned to "block" to allow reusing the
|
||||
# existing "block" object
|
||||
if block.is_full:
|
||||
tmp_block = self.allocate_immutable_block(
|
||||
prev_block=block.prev_block, token_ids=block.token_ids)
|
||||
else:
|
||||
tmp_block = self.allocate_mutable_block(
|
||||
prev_block=block.prev_block)
|
||||
tmp_block.append_token_ids(block.token_ids)
|
||||
|
||||
block_id = tmp_block.block_id
|
||||
tmp_block.block_id = None
|
||||
self._block_pool.free_block(tmp_block)
|
||||
|
||||
block.block_id = block_id # Assign block_id
|
||||
|
||||
def get_prefix_cache_hit_rate(self) -> float:
|
||||
return -1
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
"""No prefix cache for naive block allocator."""
|
||||
return True
|
||||
|
||||
def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
|
||||
# Not applicable for naive block allocator.
|
||||
return []
|
||||
|
||||
class NaiveBlock(Block):
|
||||
def append_token_ids(self, token_ids: List[int], seq_id: Optional[int] = None) -> None:
|
||||
"""Appends the given token IDs to the block and performs a
|
||||
copy-on-write if necessary.
|
||||
|
||||
Args:
|
||||
token_ids (Optional[List[int]]): The token IDs to be appended
|
||||
to the block.
|
||||
"""
|
||||
assert seq_id is not None
|
||||
self._append_token_ids_no_cow(token_ids)
|
||||
|
||||
if self._block_id is not None:
|
||||
self._block_id = (self._allocator.cow_block_if_not_appendable(
|
||||
self._cow_target, seq_id))
|
||||
942
vllm_vacc/vllm/core/block/prefix_caching_block.py
Normal file
@@ -0,0 +1,942 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Token blocks."""
|
||||
import sys
|
||||
from bisect import bisect_left
|
||||
from os.path import commonprefix
|
||||
from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set,
|
||||
Tuple)
|
||||
|
||||
from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
|
||||
get_all_blocks_recursively)
|
||||
from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device,
|
||||
DeviceAwareBlockAllocator)
|
||||
from vllm.core.block.naive_block import (BlockPool, NaiveBlock)
|
||||
from vllm_vacc.vllm.core.block.naive_block import NaiveBlockAllocator
|
||||
from vllm.core.block.prefix_caching_block import BlockTracker, assert_prefix_caching_block_or_none
|
||||
|
||||
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import Sequence
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
PrefixHash = int
|
||||
|
||||
# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME
|
||||
# so that if we find one block still holds _DEFAULT_LAST_ACCESSED_TIME,
|
||||
# then we know this block hasn't been accessed yet.
|
||||
_DEFAULT_LAST_ACCESSED_TIME = -1
|
||||
|
||||
|
||||
|
||||
|
||||
class PrefixCachingBlockAllocator(BlockAllocator):
|
||||
"""A block allocator that implements prefix caching.
|
||||
|
||||
The PrefixCachingBlockAllocator maintains a cache of blocks based on their
|
||||
content hash. It reuses blocks with the same content hash to avoid redundant
|
||||
memory allocation. The allocator also supports copy-on-write operations.
|
||||
|
||||
Args:
|
||||
num_blocks (int): The total number of blocks to manage.
|
||||
block_size (int): The size of each block in tokens.
|
||||
block_ids(Optional[Iterable[int]], optional): An optional iterable of
|
||||
block IDs. If not provided, block IDs will be assigned sequentially
|
||||
from 0 to num_blocks - 1.
|
||||
"""
|
||||
|
||||
# Note that we use 'None' as a string here instead of None because
|
||||
# as of Python 3.12, hash(None) returns a constant predictable value.
|
||||
# This could possibly make it easier to find and exploit hash
|
||||
# collisions. 'None' as a string will be hashed differently per process,
|
||||
# but consistently within the same process. This is the same as the
|
||||
# behavior of None prior to Python 3.12.
|
||||
_none_hash: int = hash('None')
|
||||
|
||||
# Implements Block.Factory.
|
||||
def __init__(
|
||||
self,
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
block_ids: Optional[Iterable[int]] = None,
|
||||
eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
|
||||
):
|
||||
if block_ids is None:
|
||||
block_ids = range(num_blocks)
|
||||
|
||||
self._block_size = block_size
|
||||
|
||||
# A mapping of prefix hash to block index. All blocks which have a
|
||||
# prefix hash will be in this dict, even if they have refcount 0.
|
||||
self._cached_blocks: Dict[PrefixHash, BlockId] = {}
|
||||
|
||||
# A list of immutable block IDs that have been touched by scheduler
|
||||
# and should be marked as computed after an entire batch of sequences
|
||||
# are scheduled.
|
||||
self._touched_blocks: Set[BlockId] = set()
|
||||
|
||||
# Used to track status of each physical block id
|
||||
self._block_tracker: Dict[BlockId, BlockTracker] = {}
|
||||
for block_id in block_ids:
|
||||
self._block_tracker[block_id] = BlockTracker()
|
||||
|
||||
# Pre-allocate "num_blocks * extra_factor" block objects.
|
||||
# The "* extra_factor" is a buffer to allow more block objects
|
||||
# than physical blocks
|
||||
extra_factor = 4
|
||||
self._block_pool = BlockPool(self._block_size, self._create_block,
|
||||
self, num_blocks * extra_factor)
|
||||
|
||||
# An allocator for blocks that do not have prefix hashes.
|
||||
self._hashless_allocator = NaiveBlockAllocator(
|
||||
create_block=self._create_block, # type: ignore
|
||||
num_blocks=num_blocks,
|
||||
block_size=block_size,
|
||||
block_ids=block_ids,
|
||||
block_pool=self._block_pool, # Share block pool here
|
||||
)
|
||||
|
||||
# Evictor used to maintain how we want to handle those computed blocks
|
||||
# if we find memory pressure is high.
|
||||
self.eviction_policy = eviction_policy
|
||||
self.evictor: Evictor = make_evictor(self.eviction_policy)
|
||||
|
||||
# We share the refcounter between allocators. This allows us to promote
|
||||
# blocks originally allocated in the hashless allocator to immutable
|
||||
# blocks.
|
||||
self._refcounter = self._hashless_allocator.refcounter
|
||||
|
||||
self._cow_tracker = CopyOnWriteTracker(
|
||||
refcounter=self._refcounter.as_readonly())
|
||||
|
||||
self.metric_data = CacheMetricData()
|
||||
|
||||
def _create_block(
|
||||
self,
|
||||
prev_block: Optional[Block],
|
||||
token_ids: List[int],
|
||||
block_size: int,
|
||||
allocator: BlockAllocator,
|
||||
block_id: Optional[int] = None,
|
||||
computed: bool = False,
|
||||
extra_hash: Optional[int] = None,
|
||||
) -> Block:
|
||||
# Bind block to self.
|
||||
allocator = self
|
||||
|
||||
return PrefixCachingBlock(
|
||||
prev_block=prev_block,
|
||||
token_ids=token_ids,
|
||||
block_size=block_size,
|
||||
block_id=block_id,
|
||||
allocator=allocator,
|
||||
computed=computed,
|
||||
extra_hash=extra_hash,
|
||||
)
|
||||
|
||||
def allocate_immutable_block(self,
|
||||
prev_block: Optional[Block],
|
||||
token_ids: List[int],
|
||||
extra_hash: Optional[int] = None,
|
||||
device: Optional[Device] = None,
|
||||
seq_id: Optional[int] = None) -> Block:
|
||||
"""Allocates an immutable block with the given token IDs, reusing cached
|
||||
blocks if possible.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[Block]): The previous block in the sequence.
|
||||
token_ids (List[int]): The token IDs to be stored in the block.
|
||||
|
||||
Returns:
|
||||
Block: The allocated immutable block.
|
||||
"""
|
||||
assert device is None
|
||||
assert_prefix_caching_block_or_none(prev_block)
|
||||
|
||||
# First, try to create a block that points to cached data
|
||||
block = self._block_pool.init_block(prev_block=prev_block,
|
||||
token_ids=token_ids,
|
||||
block_size=self._block_size,
|
||||
physical_block_id=None,
|
||||
extra_hash=extra_hash)
|
||||
assert block.content_hash is not None
|
||||
|
||||
cached_block_id = self._cached_blocks.get(block.content_hash, None)
|
||||
if cached_block_id is not None:
|
||||
self.metric_data.query(hit=True)
|
||||
block.block_id = cached_block_id
|
||||
self._incr_refcount_cached_block(block)
|
||||
return block
|
||||
self.metric_data.query(hit=False)
|
||||
self._block_pool.free_block(block)
|
||||
|
||||
# No cached block => Allocate a new block
|
||||
block = self.allocate_mutable_block(prev_block, extra_hash=extra_hash, seq_id=seq_id)
|
||||
logger.warning(f"Teng seq_id: {seq_id} block: {block.block_id} hash: {block.content_hash}")
|
||||
|
||||
block.append_token_ids(token_ids, seq_id=seq_id)
|
||||
return block
|
||||
|
||||
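# Illustrative flow (the hash value below is made up): if a full block with
# content_hash = 0xAB12 was cached earlier, a later call with the same
# prefix and token_ids reuses its block_id and records a cache hit; on a
# miss, the probe block is returned to the pool, a mutable block is
# allocated for seq_id, and the token ids are appended to it.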
def allocate_immutable_blocks(
|
||||
self,
|
||||
prev_block: Optional[Block],
|
||||
block_token_ids: List[List[int]],
|
||||
extra_hash: Optional[int] = None,
|
||||
device: Optional[Device] = None,
|
||||
seq_id: Optional[int] = None) -> List[Block]:
|
||||
blocks = []
|
||||
for token_ids in block_token_ids:
|
||||
prev_block = self.allocate_immutable_block(prev_block=prev_block,
|
||||
token_ids=token_ids,
|
||||
device=device,
|
||||
extra_hash=extra_hash,
|
||||
seq_id=seq_id)
|
||||
blocks.append(prev_block)
|
||||
return blocks
|
||||
|
||||
def allocate_mutable_block(self,
|
||||
prev_block: Optional[Block],
|
||||
extra_hash: Optional[int] = None,
|
||||
device: Optional[Device] = None,
|
||||
seq_id: Optional[int] = None) -> Block:
|
||||
"""Allocates a mutable block. If there are no free blocks, this will
|
||||
evict unused cached blocks.
|
||||
|
||||
Args:
|
||||
prev_block (Block): The previous block in the sequence.
|
||||
None is not allowed, unlike in the superclass.
|
||||
|
||||
Returns:
|
||||
Block: The allocated mutable block.
|
||||
"""
|
||||
assert device is None
|
||||
assert seq_id is not None
|
||||
assert_prefix_caching_block_or_none(prev_block)
|
||||
|
||||
block_id = self._allocate_block_id(seq_id)
|
||||
block = self._block_pool.init_block(prev_block=prev_block,
|
||||
token_ids=[],
|
||||
block_size=self._block_size,
|
||||
physical_block_id=block_id,
|
||||
extra_hash=extra_hash)
|
||||
assert not block.computed
|
||||
assert block.content_hash is None
|
||||
return block
|
||||
|
||||
def _incr_refcount_cached_block(self, block: Block) -> None:
|
||||
# Set this block to be "computed" since it is pointing to a
|
||||
# cached block id (which was already computed)
|
||||
block.computed = True
|
||||
|
||||
block_id = block.block_id
|
||||
assert block_id is not None
|
||||
|
||||
refcount = self._refcounter.incr(block_id)
|
||||
if refcount == 1:
|
||||
# In case a cached block was evicted, restore its tracking
|
||||
if block_id in self.evictor:
|
||||
self.evictor.remove(block_id)
|
||||
|
||||
self._track_block_id(block_id, computed=True)
|
||||
|
||||
def _decr_refcount_cached_block(self, block: Block) -> None:
|
||||
# Ensure this is immutable/cached block
|
||||
assert block.content_hash is not None
|
||||
|
||||
block_id = block.block_id
|
||||
assert block_id is not None
|
||||
|
||||
refcount = self._refcounter.decr(block_id)
|
||||
if refcount > 0:
|
||||
block.block_id = None
|
||||
return
|
||||
else:
|
||||
assert refcount == 0
|
||||
|
||||
# No longer used
|
||||
assert block.content_hash in self._cached_blocks
|
||||
|
||||
# Add the cached block to the evictor
|
||||
# (This keeps the cached block around so it can be reused)
|
||||
self.evictor.add(block_id, block.content_hash, block.num_tokens_total,
|
||||
self._block_tracker[block_id].last_accessed)
|
||||
|
||||
# Stop tracking the block
|
||||
self._untrack_block_id(block_id)
|
||||
|
||||
block.block_id = None
|
||||
|
||||
def _decr_refcount_hashless_block(self, block: Block, seq_id: Optional[int] = None) -> None:
|
||||
block_id = block.block_id
|
||||
assert block_id is not None
|
||||
|
||||
# We may have a fork case where block is shared,
|
||||
# in which case, we cannot remove it from tracking
|
||||
refcount = self._refcounter.get(block_id)
|
||||
if refcount == 1:
|
||||
self._untrack_block_id(block_id)
|
||||
|
||||
# Decrement refcount of the block_id, but do not free the block object
|
||||
# itself (will be handled by the caller)
|
||||
self._hashless_allocator.free(block, keep_block_object=True, seq_id=seq_id)
|
||||
|
||||
def _allocate_block_id(self, seq_id: Optional[int] = None) -> BlockId:
|
||||
"""First tries to allocate a block id from the hashless allocator,
|
||||
and if there are no blocks, then tries to evict an unused cached block.
|
||||
"""
|
||||
assert seq_id is not None
|
||||
hashless_block_id = self._maybe_allocate_hashless_block_id(seq_id=seq_id)
|
||||
if hashless_block_id is not None:
|
||||
return hashless_block_id
|
||||
|
||||
evicted_block_id = self._maybe_allocate_evicted_block_id()
|
||||
if evicted_block_id is not None:
|
||||
return evicted_block_id
|
||||
|
||||
# No block available in hashless allocator, nor in unused cache blocks.
|
||||
raise BlockAllocator.NoFreeBlocksError()
|
||||
|
||||
def _maybe_allocate_hashless_block_id(self, seq_id: Optional[int] = None) -> Optional[BlockId]:
|
||||
try:
|
||||
# Allocate mutable block and extract its block_id
|
||||
block = self._hashless_allocator.allocate_mutable_block(
|
||||
prev_block=None, seq_id=seq_id)
|
||||
block_id = block.block_id
|
||||
self._block_pool.free_block(block)
|
||||
|
||||
self._track_block_id(block_id, computed=False)
|
||||
return block_id
|
||||
except BlockAllocator.NoFreeBlocksError:
|
||||
return None
|
||||
|
||||
def _maybe_allocate_evicted_block_id(self) -> Optional[BlockId]:
|
||||
if self.evictor.num_blocks == 0:
|
||||
return None
|
||||
|
||||
# Here we get an evicted block, which is only added
|
||||
# into evictor if its ref counter is 0
|
||||
# and since its content would be changed, we need
|
||||
# to remove it from _cached_blocks's tracking list
|
||||
block_id, content_hash_to_evict = self.evictor.evict()
|
||||
|
||||
# Sanity checks
|
||||
assert content_hash_to_evict in self._cached_blocks
|
||||
_block_id = self._cached_blocks[content_hash_to_evict]
|
||||
assert self._refcounter.get(_block_id) == 0
|
||||
assert _block_id == block_id
|
||||
|
||||
self._cached_blocks.pop(content_hash_to_evict)
|
||||
|
||||
self._refcounter.incr(block_id)
|
||||
self._track_block_id(block_id, computed=False)
|
||||
|
||||
return block_id
|
||||
|
||||
def _free_block_id(self, block: Block, seq_id: Optional[int] = None) -> None:
|
||||
"""Decrements the refcount of the block. The block may be in two
|
||||
possible states: (1) immutable/cached or (2) mutable/hashless.
|
||||
In the first case, the refcount is decremented directly and the block
|
||||
may be possibly added to the evictor. In other case, hashless
|
||||
allocator free(..) with keep_block_object=True is called to only free
|
||||
the block id (since the block object may be reused by the caller)
|
||||
"""
|
||||
block_id = block.block_id
|
||||
assert block_id is not None, "Freeing unallocated block is undefined"
|
||||
|
||||
if block.content_hash is not None:
|
||||
# Immutable: This type of block is always cached, and we want to
|
||||
# keep it in the evictor for future reuse
|
||||
self._decr_refcount_cached_block(block)
|
||||
else:
|
||||
# Mutable: This type of block is not cached, so we release it
|
||||
# directly to the hashless allocator
|
||||
self._decr_refcount_hashless_block(block, seq_id=seq_id)
|
||||
|
||||
assert block.block_id is None
|
||||
|
||||
def free(self, block: Block, keep_block_object: bool = False, seq_id: Optional[int] = None) -> None:
|
||||
"""Release the block (look at free_block_id(..) docs)
|
||||
"""
|
||||
# Release the physical block index
|
||||
self._free_block_id(block, seq_id=seq_id)
|
||||
|
||||
# Release the block object to the pool
|
||||
if not keep_block_object:
|
||||
self._block_pool.free_block(block)
|
||||
|
||||
def fork(self, last_block: Block) -> List[Block]:
|
||||
"""Creates a new sequence of blocks that shares the same underlying
|
||||
memory as the original sequence.
|
||||
|
||||
Args:
|
||||
last_block (Block): The last block in the original sequence.
|
||||
|
||||
Returns:
|
||||
List[Block]: The new sequence of blocks that shares the same memory
|
||||
as the original sequence.
|
||||
"""
|
||||
source_blocks = get_all_blocks_recursively(last_block)
|
||||
|
||||
forked_blocks: List[Block] = []
|
||||
prev_block = None
|
||||
for block in source_blocks:
|
||||
block_id = block.block_id
|
||||
assert block_id is not None
|
||||
|
||||
refcount = self._refcounter.incr(block_id)
|
||||
assert refcount != 1, "can't fork free'd block_id = {}".format(
|
||||
block_id)
|
||||
|
||||
forked_block = self._block_pool.init_block(
|
||||
prev_block=prev_block,
|
||||
token_ids=block.token_ids,
|
||||
block_size=self._block_size,
|
||||
physical_block_id=block_id,
|
||||
extra_hash=block.extra_hash)
|
||||
|
||||
forked_blocks.append(forked_block)
|
||||
prev_block = forked_blocks[-1]
|
||||
|
||||
return forked_blocks
|
||||
|
||||
def get_num_free_blocks(self, seq_id: Optional[int] = None, device: Optional[Device] = None) -> int:
|
||||
assert device is None
|
||||
# The number of free blocks is the number of hashless free blocks
|
||||
# plus the number of blocks evictor could free from its list.
|
||||
return self._hashless_allocator.get_num_free_blocks(seq_id=seq_id
|
||||
) + self.evictor.num_blocks
|
||||
|
||||
def get_num_total_blocks(self, seq_id: Optional[int] = None) -> int:
|
||||
return self._hashless_allocator.get_num_total_blocks(seq_id=seq_id)
|
||||
|
||||
def get_physical_block_id(self, absolute_id: int) -> int:
|
||||
"""Returns the zero-offset block id on certain block allocator
|
||||
given the absolute block id.
|
||||
|
||||
Args:
|
||||
absolute_id (int): The absolute block id for the block
|
||||
in whole allocator.
|
||||
|
||||
Returns:
|
||||
int: The zero-offset block id on the given device.
|
||||
"""
|
||||
return sorted(self.all_block_ids).index(absolute_id)
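# A minimal sketch of the zero-offset mapping above; the block ids are
# hypothetical and would normally come from the hashless allocator.
def _example_zero_offset_mapping() -> int:
    all_block_ids = frozenset({100, 101, 102})
    absolute_id = 101
    # Same computation as get_physical_block_id: the position of the id
    # within the sorted set of ids owned by this allocator.
    return sorted(all_block_ids).index(absolute_id)  # -> 1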
|
||||
|
||||
@property
|
||||
def all_block_ids(self) -> FrozenSet[int]:
|
||||
return self._hashless_allocator.all_block_ids
|
||||
|
||||
def get_prefix_cache_hit_rate(self) -> float:
|
||||
return self.metric_data.get_hit_rate()
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
"""Reset prefix cache. This function may be used in RLHF
|
||||
flows to invalidate prefix caching after the weights are updated,
|
||||
or used for resetting prefix caching status for benchmarking.
|
||||
|
||||
Returns:
|
||||
bool: True if the prefix cache is successfully reset,
|
||||
False otherwise.
|
||||
"""
|
||||
num_used_blocks = (self.get_num_total_blocks() -
|
||||
self.get_num_free_blocks())
|
||||
if num_used_blocks > 0:
|
||||
logger.warning(
|
||||
"Failed to reset prefix cache because some "
|
||||
"blocks (%d) are not freed yet", num_used_blocks)
|
||||
return False
|
||||
|
||||
# Free all blocks in the evictor.
|
||||
while (block_id :=
|
||||
self._maybe_allocate_evicted_block_id()) is not None:
|
||||
# TODO: Teng
|
||||
self._hashless_allocator.free_block_id(block_id)
|
||||
|
||||
# Should not have any cached blocks because all blocks are evicted.
|
||||
assert not self._cached_blocks
|
||||
|
||||
# Reset the evictor.
|
||||
self.evictor = make_evictor(self.eviction_policy)
|
||||
|
||||
# Reset the block tracker.
|
||||
for block_id in self._block_tracker:
|
||||
self._block_tracker[block_id] = BlockTracker()
|
||||
|
||||
# Reset the metrics.
|
||||
self.metric_data = CacheMetricData()
|
||||
|
||||
logger.info("Successfully reset prefix cache")
|
||||
return True
|
||||
|
||||
def is_block_cached(self, block: Block) -> bool:
|
||||
assert block.content_hash is not None
|
||||
return block.content_hash in self._cached_blocks
|
||||
|
||||
def promote_to_immutable_block(self, block: Block, seq_id: Optional[int] = None) -> BlockId:
|
||||
"""Once a mutable block is full, it can be promoted to an immutable
|
||||
block. This means that its content can be referenced by future blocks
|
||||
having the same prefix.
|
||||
|
||||
Note that if we already have a cached block with the same content, we
|
||||
will replace the newly-promoted block's mapping with the existing cached
|
||||
block id.
|
||||
|
||||
Args:
|
||||
block: The mutable block to be promoted.
|
||||
|
||||
Returns:
|
||||
BlockId: Either the original block index, or the block index of
|
||||
the previously cached block matching the same content.
|
||||
"""
|
||||
# Ensure block can be promoted
|
||||
assert block.content_hash is not None
|
||||
assert block.block_id is not None
|
||||
assert self._refcounter.get(block.block_id) > 0
|
||||
|
||||
if block.content_hash not in self._cached_blocks:
|
||||
# No cached content hash => Set this block as cached.
|
||||
# Note that this block cannot be marked as computed yet
|
||||
# because other sequences in the same batch cannot reuse
|
||||
# this block.
|
||||
self._cached_blocks[block.content_hash] = block.block_id
|
||||
# Mark this block as touched so that it can be marked as
|
||||
# computed after the entire batch of sequences are scheduled.
|
||||
self._touched_blocks.add(block.block_id)
|
||||
return block.block_id
|
||||
|
||||
# Reuse the cached content hash
|
||||
self._decr_refcount_hashless_block(block, seq_id=seq_id)
|
||||
block.block_id = self._cached_blocks[block.content_hash]
|
||||
|
||||
# Increment refcount of the cached block and (possibly) restore
|
||||
# it from the evictor.
|
||||
# Note that in this case, the block is marked as computed
|
||||
self._incr_refcount_cached_block(block)
|
||||
|
||||
return block.block_id
|
||||
|
||||
def cow_block_if_not_appendable(self, block: Block, seq_id: Optional[int] = None) -> BlockId:
|
||||
"""Performs a copy-on-write operation on the given block if it is not
|
||||
appendable.
|
||||
|
||||
Args:
|
||||
block (Block): The block to check for copy-on-write.
|
||||
|
||||
Returns:
|
||||
BlockId: The block index of the new block if a copy-on-write
|
||||
operation was performed, or the original block index if
|
||||
no copy-on-write was necessary.
|
||||
"""
|
||||
src_block_id = block.block_id
|
||||
assert src_block_id is not None
|
||||
|
||||
if self._cow_tracker.is_appendable(block):
|
||||
return src_block_id
|
||||
|
||||
self._free_block_id(block)
|
||||
trg_block_id = self._allocate_block_id()
|
||||
|
||||
self._cow_tracker.record_cow(src_block_id, trg_block_id)
|
||||
|
||||
return trg_block_id
|
||||
|
||||
def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]:
|
||||
"""Returns the copy-on-write source->destination mapping and clears it.
|
||||
|
||||
Returns:
|
||||
List[Tuple[BlockId, BlockId]]: A list mapping source
|
||||
block indices to destination block indices.
|
||||
"""
|
||||
return self._cow_tracker.clear_cows()
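# A minimal sketch of how a caller might consume the (src, dst) pairs
# returned above; `copy_kv_page` is a hypothetical callable standing in
# for the engine's actual KV-cache copy step.
def _example_apply_cows(allocator, copy_kv_page) -> None:
    for src_block_id, dst_block_id in allocator.clear_copy_on_writes():
        copy_kv_page(src_block_id, dst_block_id)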
|
||||
|
||||
def mark_blocks_as_accessed(self, block_ids: List[int],
|
||||
now: float) -> None:
|
||||
"""Mark blocks as accessed, used in prefix caching.
|
||||
|
||||
If the block is added into evictor, we need to update corresponding
|
||||
info in evictor's metadata.
|
||||
"""
|
||||
|
||||
for block_id in block_ids:
|
||||
if self._block_tracker[block_id].active:
|
||||
self._block_tracker[block_id].last_accessed = now
|
||||
elif block_id in self.evictor:
|
||||
self.evictor.update(block_id, now)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Mark block as accessed which is not belonged to GPU")
|
||||
|
||||
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
|
||||
# Mark all touched blocks as computed.
|
||||
for block_id in self._touched_blocks:
|
||||
self._block_tracker[block_id].computed = True
|
||||
self._touched_blocks.clear()
|
||||
|
||||
def _track_block_id(self, block_id: Optional[BlockId],
|
||||
computed: bool) -> None:
|
||||
assert block_id is not None
|
||||
self._block_tracker[block_id].enable()
|
||||
self._block_tracker[block_id].computed = computed
|
||||
|
||||
def _untrack_block_id(self, block_id: Optional[BlockId]) -> None:
|
||||
assert block_id is not None
|
||||
self._block_tracker[block_id].disable()
|
||||
|
||||
def block_is_computed(self, block_id: int) -> bool:
|
||||
if self._block_tracker[block_id].active:
|
||||
return self._block_tracker[block_id].computed
|
||||
else:
|
||||
return block_id in self.evictor
|
||||
|
||||
def get_common_computed_block_ids(
|
||||
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
|
||||
"""Return the block ids that are common for a given sequence group.
|
||||
|
||||
Only blocks that are immutable and already marked as
computed are taken into consideration.
|
||||
"""
|
||||
|
||||
# NOTE We exclude the last block to avoid the case where the entire
|
||||
# prompt is cached. This would cause erroneous behavior in model
|
||||
# runner.
|
||||
|
||||
# It returns a list of int although type annotation says list of string.
|
||||
if len(computed_seq_block_ids) == 1:
|
||||
return computed_seq_block_ids[0]
|
||||
|
||||
return commonprefix([
|
||||
ids for ids in computed_seq_block_ids # type: ignore
|
||||
if ids
|
||||
])
|
||||
|
||||
def get_num_full_blocks_touched(self, blocks: List[Block]) -> int:
|
||||
"""Returns the number of full blocks that will be touched by
|
||||
swapping in/out.
|
||||
|
||||
Args:
|
||||
blocks: List of blocks to be swapped.
|
||||
Returns:
|
||||
int: the number of full blocks that will be touched by
|
||||
swapping in/out the given blocks. Non full blocks are ignored
|
||||
when deciding the number of blocks to touch.
|
||||
"""
|
||||
num_touched_blocks: int = 0
|
||||
for block in blocks:
|
||||
# If the block has a match in the cache and the cached
|
||||
# block is not referenced, then we still count it as a
|
||||
# touched block
|
||||
if block.is_full and (not self.is_block_cached(block) or \
|
||||
(block.content_hash is not None and \
|
||||
self._cached_blocks[block.content_hash] in \
|
||||
self.evictor)):
|
||||
num_touched_blocks += 1
|
||||
return num_touched_blocks
|
||||
|
||||
def swap_out(self, blocks: List[Block]) -> None:
|
||||
"""Execute the swap out actions. Basically just free the
|
||||
given blocks.
|
||||
|
||||
Args:
|
||||
blocks: List of blocks to be swapped out.
|
||||
"""
|
||||
for block in blocks:
|
||||
self._free_block_id(block)
|
||||
|
||||
def swap_in(self, blocks: List[Block]) -> None:
|
||||
"""Execute the swap in actions. Change the block id from
|
||||
old allocator to current allocator for each block to finish
|
||||
the block table update.
|
||||
|
||||
Args:
|
||||
blocks: List of blocks to be swapped in.
|
||||
"""
|
||||
for block in blocks:
|
||||
# Here we allocate either immutable or mutable block and then
|
||||
# extract its block_id. Note that the block object is released
|
||||
# and the block_id is assigned to "block" to allow reusing the
|
||||
# existing "block" object
|
||||
if block.is_full:
|
||||
tmp_block = self.allocate_immutable_block(
|
||||
prev_block=block.prev_block,
|
||||
token_ids=block.token_ids,
|
||||
extra_hash=block.extra_hash)
|
||||
else:
|
||||
tmp_block = self.allocate_mutable_block(
|
||||
prev_block=block.prev_block, extra_hash=block.extra_hash)
|
||||
tmp_block.append_token_ids(block.token_ids)
|
||||
|
||||
block_id = tmp_block.block_id
|
||||
self._block_pool.free_block(tmp_block)
|
||||
|
||||
block.block_id = block_id # Assign block_id
|
||||
|
||||
def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
|
||||
"""
|
||||
Given a list of block hashes, return the prefix of the block hashes that
|
||||
are all cached.
|
||||
|
||||
Since a block's content hash includes the hashes of all previous blocks,
and we only allocate/deallocate blocks for the entire sequence at once, a
cached block implies that all previous blocks are also cached. With this
property, we can use binary search to find the prefix of cached blocks.
|
||||
|
||||
Args:
|
||||
block_hashes (List[int]): The list of block hashes.
|
||||
|
||||
Returns:
|
||||
List[int]: The prefix of the `block_hashes` that are cached.
|
||||
"""
|
||||
|
||||
def _block_is_cached(block_hash: PrefixHash) -> bool:
|
||||
if block_hash not in self._cached_blocks:
|
||||
return False
|
||||
|
||||
cached_block_id = self._cached_blocks[block_hash]
|
||||
# We only consider the blocks that are marked as computed.
|
||||
return self.block_is_computed(cached_block_id)
|
||||
|
||||
def _bisect_left(a, x, key: Callable[[PrefixHash], bool]) -> int:
|
||||
|
||||
# Python < 3.10 does not support the `key` argument of bisect_left
|
||||
if sys.version_info < (3, 10):
|
||||
a = [key(e) for e in a]
|
||||
return bisect_left(a, x)
|
||||
else:
|
||||
return bisect_left(a, x, key=key)
|
||||
|
||||
# Look for the first block that's not cached, and returns the prefix
|
||||
# i.e. blocks that are cached.
|
||||
idx = _bisect_left(block_hashes,
|
||||
True,
|
||||
key=lambda x: not _block_is_cached(x))
|
||||
return block_hashes[:idx]
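# A minimal standalone sketch of the bisect trick used above: find the
# length of the leading run for which a monotone predicate holds. The
# hashes and cached set are hypothetical; requires Python 3.10+ for the
# `key` argument (mirroring the fallback handled above).
def _example_cached_prefix_length() -> int:
    from bisect import bisect_left
    block_hashes = [11, 22, 33, 44]
    cached = {11, 22}  # prefix property: a cached hash implies all
                       # earlier hashes are cached as well
    # The key maps each hash to False (cached) / True (not cached); the
    # first True marks the end of the cached prefix.
    return bisect_left(block_hashes, True,
                       key=lambda h: h not in cached)  # -> 2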
|
||||
|
||||
|
||||
class PrefixCachingBlock(Block):
|
||||
"""A block implementation that supports prefix caching.
|
||||
|
||||
The PrefixCachingBlock class represents a block of token IDs with prefix
|
||||
caching capabilities. It wraps a NaiveBlock internally and provides
|
||||
additional functionality for content hashing and promoting immutable blocks
|
||||
with the prefix caching allocator.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[PrefixCachingBlock]): The previous block in the
|
||||
sequence.
|
||||
token_ids (List[int]): The initial token IDs to be stored in the block.
|
||||
block_size (int): The maximum number of token IDs that can be stored in
|
||||
the block.
|
||||
allocator (BlockAllocator): The prefix
|
||||
caching block allocator associated with this block.
|
||||
block_id (Optional[int], optional): The physical block index
|
||||
of this block. Defaults to None.
|
||||
extra_hash (Optional[int]): The hash value of additional factors
|
||||
such as adapters that influence the block, apart from the token_ids.
|
||||
"""
|
||||
|
||||
# Note that we use 'None' as a string here instead of None because
|
||||
# as of Python 3.12, hash(None) returns a constant predictable value.
|
||||
# This could possibly make it easier to find and exploit hash
|
||||
# collisions. 'None' as a string will be hashed differently per process,
|
||||
# but consistently within the same process. This is the same as the
|
||||
# behavior of None prior to Python 3.12.
|
||||
_none_hash: int = hash('None')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prev_block: Optional[Block],
|
||||
token_ids: List[int],
|
||||
block_size: int,
|
||||
allocator: BlockAllocator,
|
||||
block_id: Optional[int] = None,
|
||||
computed: bool = False,
|
||||
extra_hash: Optional[int] = None,
|
||||
):
|
||||
assert isinstance(allocator, PrefixCachingBlockAllocator), (
|
||||
"Currently this class is only tested with "
|
||||
"PrefixCachingBlockAllocator. Got instead allocator = {}".format(
|
||||
allocator))
|
||||
assert_prefix_caching_block_or_none(prev_block)
|
||||
|
||||
self._prev_block = prev_block
|
||||
self._cached_content_hash: Optional[int] = None
|
||||
self._cached_num_tokens_total: int = 0
|
||||
self._allocator = allocator
|
||||
self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
|
||||
self._computed = computed
|
||||
self._extra_hash = extra_hash
|
||||
|
||||
# On the first time, we create the block object, and next we only
|
||||
# reinitialize it
|
||||
if hasattr(self, "_block"):
|
||||
self._block.__init__( # type: ignore[has-type]
|
||||
prev_block=prev_block,
|
||||
token_ids=token_ids,
|
||||
block_size=block_size,
|
||||
block_id=block_id,
|
||||
allocator=self._allocator)
|
||||
else:
|
||||
self._block = NaiveBlock(prev_block=prev_block,
|
||||
token_ids=token_ids,
|
||||
block_size=block_size,
|
||||
block_id=block_id,
|
||||
allocator=self._allocator)
|
||||
|
||||
self._update_num_tokens_total()
|
||||
|
||||
def _update_num_tokens_total(self):
|
||||
"""Incrementally computes the number of tokens that there is
|
||||
till the current block (included)
|
||||
"""
|
||||
res = 0
|
||||
|
||||
# Add all previous blocks
|
||||
if self._prev_block is not None:
|
||||
res += self._prev_block.num_tokens_total
|
||||
|
||||
# Add current block
|
||||
res += len(self.token_ids)
|
||||
|
||||
self._cached_num_tokens_total = res
|
||||
|
||||
@property
|
||||
def computed(self) -> bool:
|
||||
return self._computed
|
||||
|
||||
@computed.setter
|
||||
def computed(self, value) -> None:
|
||||
self._computed = value
|
||||
|
||||
@property
|
||||
def last_accessed(self) -> float:
|
||||
return self._last_accessed
|
||||
|
||||
@last_accessed.setter
|
||||
def last_accessed(self, last_accessed_ts: float):
|
||||
self._last_accessed = last_accessed_ts
|
||||
|
||||
def append_token_ids(self, token_ids: List[int], seq_id: Optional[int] = None) -> None:
|
||||
"""Appends the given token IDs to the block and registers the block as
|
||||
immutable if the block becomes full.
|
||||
|
||||
Args:
|
||||
token_ids (List[int]): The token IDs to be appended to the block.
|
||||
"""
|
||||
# Ensure this is mutable block (not promoted)
|
||||
assert self.content_hash is None
|
||||
assert not self.computed
|
||||
assert seq_id is not None
|
||||
|
||||
if len(token_ids) == 0:
|
||||
return
|
||||
|
||||
# Ensure there are input tokens
|
||||
assert token_ids, "Got token_ids = {}".format(token_ids)
|
||||
|
||||
# Naive block handles CoW.
|
||||
self._block.append_token_ids(token_ids, seq_id=seq_id)
|
||||
self._update_num_tokens_total()
|
||||
|
||||
# If the content hash is present, then the block can be made immutable.
|
||||
# Register ourselves with the allocator, potentially replacing the
|
||||
# physical block index.
|
||||
if self.content_hash is not None:
|
||||
self.block_id = self._allocator.promote_to_immutable_block(self, seq_id=seq_id)
|
||||
|
||||
@property
|
||||
def block_id(self) -> Optional[int]:
|
||||
return self._block.block_id
|
||||
|
||||
@block_id.setter
|
||||
def block_id(self, value) -> None:
|
||||
self._block.block_id = value
|
||||
|
||||
@property
|
||||
def is_full(self) -> bool:
|
||||
return self._block.is_full
|
||||
|
||||
@property
|
||||
def num_empty_slots(self) -> int:
|
||||
return self._block.num_empty_slots
|
||||
|
||||
@property
|
||||
def num_tokens_total(self) -> int:
|
||||
return self._cached_num_tokens_total
|
||||
|
||||
@property
|
||||
def block_size(self) -> int:
|
||||
return self._block.block_size
|
||||
|
||||
@property
|
||||
def token_ids(self) -> List[int]:
|
||||
return self._block.token_ids
|
||||
|
||||
@property
|
||||
def prev_block(self) -> Optional[Block]:
|
||||
return self._prev_block
|
||||
|
||||
@property
|
||||
def extra_hash(self) -> Optional[int]:
|
||||
return self._extra_hash
|
||||
|
||||
@property
|
||||
def content_hash(self) -> Optional[int]:
|
||||
"""Return the content-based hash of the current block, or None if it is
|
||||
not yet defined.
|
||||
|
||||
For the content-based hash to be defined, the current block must be
|
||||
full.
|
||||
"""
|
||||
# If the hash is already computed, return it.
|
||||
if self._cached_content_hash is not None:
|
||||
return self._cached_content_hash
|
||||
|
||||
# We cannot compute a hash for the current block because it is not full.
|
||||
if not self.is_full:
|
||||
return None
|
||||
|
||||
is_first_block = self._prev_block is None
|
||||
prev_block_hash = (
|
||||
self._none_hash if is_first_block else
|
||||
self._prev_block.content_hash # type: ignore
|
||||
)
|
||||
|
||||
# Previous block exists but does not yet have a hash.
|
||||
# Return no hash in this case.
|
||||
if prev_block_hash == self._none_hash and not is_first_block:
|
||||
return None
|
||||
|
||||
self._cached_content_hash = PrefixCachingBlock.hash_block_tokens(
|
||||
is_first_block,
|
||||
prev_block_hash,
|
||||
cur_block_token_ids=self.token_ids,
|
||||
extra_hash=self._extra_hash)
|
||||
return self._cached_content_hash
|
||||
|
||||
@classmethod
|
||||
def hash_block_tokens(cls,
|
||||
is_first_block: bool,
|
||||
prev_block_hash: Optional[int],
|
||||
cur_block_token_ids: List[int],
|
||||
extra_hash: Optional[int] = None) -> int:
|
||||
"""Computes a hash value corresponding to the contents of a block and
|
||||
the contents of the preceding block(s). The hash value is used for
|
||||
prefix caching.
|
||||
|
||||
Parameters:
|
||||
- is_first_block (bool): A flag indicating if the block is the first in
|
||||
the sequence.
|
||||
- prev_block_hash (Optional[int]): The hash of the previous block. None
|
||||
if this is the first block.
|
||||
- cur_block_token_ids (List[int]): A list of token ids in the current
|
||||
block. The current block is assumed to be full.
|
||||
- extra_hash (Optional[int]): The hash value of additional factors
|
||||
such as adapters that influence the block, apart from the token_ids.
|
||||
|
||||
Returns:
|
||||
- int: The computed hash value for the block.
|
||||
"""
|
||||
if is_first_block and prev_block_hash is None:
|
||||
prev_block_hash = cls._none_hash
|
||||
return hash((is_first_block, prev_block_hash, *cur_block_token_ids,
|
||||
extra_hash))
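# A minimal sketch of how content hashes chain across blocks: the second
# block's hash depends on the first block's hash, so only identical
# prefixes produce identical hashes. The token ids are hypothetical.
def _example_chained_block_hashes() -> int:
    h0 = PrefixCachingBlock.hash_block_tokens(
        is_first_block=True,
        prev_block_hash=None,
        cur_block_token_ids=[1, 2, 3, 4])
    return PrefixCachingBlock.hash_block_tokens(
        is_first_block=False,
        prev_block_hash=h0,
        cur_block_token_ids=[5, 6, 7, 8])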
|
||||
575
vllm_vacc/vllm/core/block_manager.py
Normal file
575
vllm_vacc/vllm/core/block_manager.py
Normal file
@@ -0,0 +1,575 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""A block manager that manages token blocks."""
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Tuple
|
||||
|
||||
from vllm.core.block.block_table import BlockTable
|
||||
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
|
||||
from vllm.core.block.interfaces import Block
|
||||
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
|
||||
LastAccessBlocksTracker)
|
||||
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
|
||||
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
|
||||
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
|
||||
from vllm.utils import Device
|
||||
import os
|
||||
SeqId = int
|
||||
EncoderSeqId = str
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from vllm_vacc.vllm.model_executor.models.vars import LLM_MAX_PREFILL_SEQ_LEN
|
||||
|
||||
max_seq_num = int(os.getenv("MAX_SEQ_NUM", 4))
|
||||
if max_seq_num not in [1, 2, 4]:
|
||||
max_seq_num = 4
|
||||
|
||||
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
||||
|
||||
class SelfAttnBlockSpaceManager(BlockSpaceManager):
|
||||
"""BlockSpaceManager which manages the allocation of KV cache.
|
||||
|
||||
It owns responsibility for allocation, swapping, allocating memory for
|
||||
autoregressively-generated tokens, and other advanced features such as
|
||||
prefix caching, forking/copy-on-write, and sliding-window memory allocation.
|
||||
|
||||
This class implements the design described in
|
||||
https://github.com/vllm-project/vllm/pull/3492.
|
||||
|
||||
Lookahead slots
|
||||
The block manager has the notion of a "lookahead slot". These are slots
|
||||
in the KV cache that are allocated for a sequence. Unlike the other
|
||||
allocated slots, the content of these slots is undefined -- the worker
|
||||
may use the memory allocations in any way.
|
||||
|
||||
In practice, a worker could use these lookahead slots to run multiple
|
||||
forward passes for a single scheduler invocation. Each successive
|
||||
forward pass would write KV activations to the corresponding lookahead
|
||||
slot. This allows low inter-token latency use-cases, where the overhead
|
||||
of continuous batching scheduling is amortized over >1 generated tokens.
|
||||
|
||||
Speculative decoding uses lookahead slots to store KV activations of
|
||||
proposal tokens.
|
||||
|
||||
See https://github.com/vllm-project/vllm/pull/3250 for more information
|
||||
on lookahead scheduling.
|
||||
|
||||
Args:
|
||||
block_size (int): The size of each memory block.
|
||||
num_gpu_blocks (int): The number of memory blocks allocated on GPU.
|
||||
num_cpu_blocks (int): The number of memory blocks allocated on CPU.
|
||||
watermark (float, optional): The threshold used for memory swapping.
|
||||
Defaults to 0.01.
|
||||
sliding_window (Optional[int], optional): The size of the sliding
|
||||
window. Defaults to None.
|
||||
enable_caching (bool, optional): Flag indicating whether caching is
|
||||
enabled. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_size: int,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
watermark: float = 0.01,
|
||||
sliding_window: Optional[int] = None,
|
||||
enable_caching: bool = False,
|
||||
) -> None:
|
||||
self.block_size = block_size
|
||||
self.num_total_gpu_blocks = num_gpu_blocks
|
||||
self.num_total_cpu_blocks = num_cpu_blocks
|
||||
self.per_gpu_blocks = num_gpu_blocks // max_seq_num
|
||||
|
||||
self.sliding_window = sliding_window
|
||||
# max_block_sliding_window is the max number of blocks that need to be
|
||||
# allocated
|
||||
self.max_block_sliding_window = None
|
||||
if sliding_window is not None:
|
||||
# +1 here because // rounds down
|
||||
num_blocks = sliding_window // block_size + 1
|
||||
# +1 here because the last block may not be full,
|
||||
# and so the sequence stretches one more block at the beginning
|
||||
# For example, if sliding_window is 3 and block_size is 4,
|
||||
# we may need 2 blocks when the second block only holds 1 token.
|
||||
self.max_block_sliding_window = num_blocks + 1
|
||||
|
||||
self.watermark = watermark
|
||||
assert watermark >= 0.0
|
||||
|
||||
self.enable_caching = enable_caching
|
||||
|
||||
# self.watermark_blocks = 1 # for test
|
||||
self.watermark_blocks = int(watermark * self.per_gpu_blocks)
|
||||
|
||||
self.block_allocator = CpuGpuBlockAllocator.create(
|
||||
allocator_type="prefix_caching" if enable_caching else "naive",
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
block_size=block_size,
|
||||
)
|
||||
|
||||
self.block_tables: Dict[SeqId, BlockTable] = {}
|
||||
self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {}
|
||||
|
||||
self._computed_blocks_tracker = ComputedBlocksTracker(
|
||||
self.block_allocator, self.block_size, self.enable_caching)
|
||||
self._last_access_blocks_tracker = LastAccessBlocksTracker(
|
||||
self.block_allocator)
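# A minimal sketch of the sliding-window bound computed above, using the
# numbers from the comment: sliding_window=3, block_size=4 needs
# 3 // 4 + 1 = 1 block for the window itself, plus one more because the
# window may straddle a block boundary.
def _example_max_block_sliding_window(sliding_window: int = 3,
                                      block_size: int = 4) -> int:
    return sliding_window // block_size + 1 + 1  # -> 2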
|
||||
|
||||
def can_allocate(self,
|
||||
seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int = 0) -> AllocStatus:
|
||||
# FIXME(woosuk): Here we assume that all sequences in the group share
|
||||
# the same prompt. This may not be true for preempted sequences.
|
||||
|
||||
check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
|
||||
|
||||
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
|
||||
num_required_blocks = BlockTable.get_num_required_blocks(
|
||||
seq.get_token_ids(),
|
||||
block_size=self.block_size,
|
||||
num_lookahead_slots=num_lookahead_slots,
|
||||
)
|
||||
|
||||
if seq_group.is_encoder_decoder():
|
||||
encoder_seq = seq_group.get_encoder_seq()
|
||||
assert encoder_seq is not None
|
||||
num_required_blocks += BlockTable.get_num_required_blocks(
|
||||
encoder_seq.get_token_ids(),
|
||||
block_size=self.block_size,
|
||||
)
|
||||
|
||||
if self.max_block_sliding_window is not None:
|
||||
num_required_blocks = min(num_required_blocks,
|
||||
self.max_block_sliding_window)
|
||||
|
||||
# TODO
|
||||
# limitations to be removed later
|
||||
required_size = num_required_blocks * self.block_size
|
||||
if required_size > LLM_MAX_PREFILL_SEQ_LEN:
|
||||
logging.warning(
|
||||
f"This model's maximum input seq length limit is "
|
||||
f"{LLM_MAX_PREFILL_SEQ_LEN} tokens. However, you requested "
|
||||
f"({required_size} in the input messages, "
|
||||
f"Please reduce the length of the input messages.")
|
||||
return AllocStatus.NEVER
|
||||
|
||||
# Use watermark to avoid frequent cache eviction.
|
||||
# NOTE:
|
||||
# num of the gpu blocks for each seq_id might not be the same
|
||||
# since each seq can use different blk group number
|
||||
|
||||
total_gpu_blocks = self.block_allocator.get_num_total_blocks(device=Device.GPU, seq_id=seq.seq_id)
|
||||
if (total_gpu_blocks
|
||||
- num_required_blocks < self.watermark_blocks):
|
||||
self.block_allocator.get_num_total_blocks(device=Device.GPU, seq_id=seq.seq_id)
|
||||
return AllocStatus.NEVER
|
||||
|
||||
# NOTE: num_required_blocks should be aligned up to the block group size (e.g. 8K tokens) before the comparison
|
||||
block_num = env_blk_grp_size // self.block_size
|
||||
if num_required_blocks % block_num: # align
|
||||
num_required_blocks = (num_required_blocks // block_num + 1) * block_num
|
||||
|
||||
if total_gpu_blocks % block_num:
|
||||
total_gpu_blocks = (total_gpu_blocks // block_num) * block_num
|
||||
# Use the aligned memory size to decide whether to reject the request.
|
||||
if total_gpu_blocks - num_required_blocks < self.watermark_blocks:
|
||||
logging.warning("gpu memory may not enough, please try shorter sequence")
|
||||
return AllocStatus.NEVER
|
||||
|
||||
num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
|
||||
device=Device.GPU, seq_id=seq.seq_id)
|
||||
# logging.warning(f"free blocks: {num_free_gpu_blocks} required: {num_required_blocks} watermark: {self.watermark_blocks}")
|
||||
|
||||
if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
|
||||
return AllocStatus.OK
|
||||
else:
|
||||
# logging.warning(f"free_blocks:{num_free_gpu_blocks} "
|
||||
# f"required_blocks:{num_required_blocks} "
|
||||
# f"watermark:{self.watermark_blocks} "
|
||||
# f"allocate seq:{seq.seq_id} later")
|
||||
return AllocStatus.LATER
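# A minimal sketch of the block-group alignment performed in
# can_allocate. The numbers are hypothetical: with block_size=16 and a
# block group of 8192 tokens, block_num is 512 and a request needing
# 600 blocks is rounded up to the next multiple of 512.
def _example_align_required_blocks(num_required_blocks: int = 600,
                                   block_size: int = 16,
                                   blk_grp_size: int = 8192) -> int:
    block_num = blk_grp_size // block_size  # blocks per group
    if num_required_blocks % block_num:     # round up to a full group
        num_required_blocks = (num_required_blocks // block_num + 1) * block_num
    return num_required_blocks              # -> 1024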
|
||||
|
||||
|
||||
def _allocate_sequence(self, seq: Sequence) -> BlockTable:
|
||||
block_table = BlockTable(
|
||||
block_size=self.block_size,
|
||||
block_allocator=self.block_allocator,
|
||||
max_block_sliding_window=self.max_block_sliding_window,
|
||||
)
|
||||
if seq.get_token_ids():
|
||||
# NOTE: If there are any factors affecting the block besides
|
||||
# token_ids, they should be added as input to extra_hash.
|
||||
extra_hash = seq.extra_hash()
|
||||
|
||||
# Add blocks to the block table only if the sequence is non empty.
|
||||
block_table.allocate(token_ids=seq.get_token_ids(),
|
||||
extra_hash=extra_hash,
|
||||
seq_id=seq.seq_id)
|
||||
|
||||
return block_table
|
||||
|
||||
def allocate(self, seq_group: SequenceGroup) -> None:
|
||||
|
||||
# Allocate self-attention block tables for decoder sequences
|
||||
waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
|
||||
assert not (set(seq.seq_id for seq in waiting_seqs)
|
||||
& self.block_tables.keys()), "block table already exists"
|
||||
|
||||
# NOTE: Here we assume that all sequences in the group have the same
|
||||
# prompt.
|
||||
seq = waiting_seqs[0]
|
||||
block_table: BlockTable = self._allocate_sequence(seq)
|
||||
self.block_tables[seq.seq_id] = block_table
|
||||
|
||||
# Track seq
|
||||
self._last_access_blocks_tracker.add_seq(seq.seq_id)
|
||||
|
||||
# Assign the block table for each sequence.
|
||||
for seq in waiting_seqs[1:]:
|
||||
self.block_tables[seq.seq_id] = block_table.fork()
|
||||
|
||||
# Track seq
|
||||
self._last_access_blocks_tracker.add_seq(seq.seq_id)
|
||||
|
||||
# Allocate cross-attention block table for encoder sequence
|
||||
#
|
||||
# NOTE: Here we assume that all sequences in the group have the same
|
||||
# encoder prompt.
|
||||
request_id = seq_group.request_id
|
||||
|
||||
assert (request_id
|
||||
not in self.cross_block_tables), \
|
||||
"block table already exists"
|
||||
|
||||
check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
|
||||
|
||||
if seq_group.is_encoder_decoder():
|
||||
encoder_seq = seq_group.get_encoder_seq()
|
||||
assert encoder_seq is not None
|
||||
block_table = self._allocate_sequence(encoder_seq)
|
||||
self.cross_block_tables[request_id] = block_table
|
||||
|
||||
def can_append_slots(self, seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int) -> bool:
|
||||
"""Determine if there is enough space in the GPU KV cache to continue
|
||||
generation of the specified sequence group.
|
||||
|
||||
We use a worst-case heuristic: assume each touched block will require a
|
||||
new allocation (either via CoW or new block). We can append slots if the
|
||||
number of touched blocks is less than the number of free blocks.
|
||||
|
||||
"Lookahead slots" are slots that are allocated in addition to the slots
|
||||
for known tokens. The contents of the lookahead slots are not defined.
|
||||
This is used by speculative decoding when speculating future tokens.
|
||||
"""
|
||||
|
||||
num_touched_blocks = 0
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
block_table = self.block_tables[seq.seq_id]
|
||||
|
||||
num_touched_blocks += (
|
||||
block_table.get_num_blocks_touched_by_append_slots(
|
||||
token_ids=block_table.get_unseen_token_ids(
|
||||
seq.get_token_ids()),
|
||||
num_lookahead_slots=num_lookahead_slots,
|
||||
))
|
||||
|
||||
num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
|
||||
Device.GPU, seq_id=seq.seq_id)
|
||||
# NOTE: if False, trigger RECOMPUTE
|
||||
return num_touched_blocks <= num_free_gpu_blocks
|
||||
|
||||
def append_slots(
|
||||
self,
|
||||
seq: Sequence,
|
||||
num_lookahead_slots: int,
|
||||
) -> List[Tuple[int, int]]:
|
||||
|
||||
block_table = self.block_tables[seq.seq_id]
|
||||
|
||||
block_table.append_token_ids(
|
||||
token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()),
|
||||
num_lookahead_slots=num_lookahead_slots,
|
||||
num_computed_slots=seq.data.get_num_computed_tokens(),
|
||||
extra_hash=seq.extra_hash(),
|
||||
seq_id = seq.seq_id
|
||||
)
|
||||
# Return any new copy-on-writes.
|
||||
new_cows = self.block_allocator.clear_copy_on_writes(seq.seq_id)
|
||||
return new_cows
|
||||
|
||||
def free(self, seq: Sequence) -> None:
|
||||
seq_id = seq.seq_id
|
||||
|
||||
if seq_id not in self.block_tables:
|
||||
# Already freed or haven't been scheduled yet.
|
||||
return
|
||||
|
||||
# Update seq block ids with the latest access time
|
||||
self._last_access_blocks_tracker.update_seq_blocks_last_access(
|
||||
seq_id, self.block_tables[seq.seq_id].physical_block_ids)
|
||||
|
||||
# Untrack seq
|
||||
self._last_access_blocks_tracker.remove_seq(seq_id)
|
||||
self._computed_blocks_tracker.remove_seq(seq_id)
|
||||
|
||||
# Free table/blocks
|
||||
self.block_tables[seq_id].free(seq_id)
|
||||
del self.block_tables[seq_id]
|
||||
|
||||
def remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None:
|
||||
seq_id = seq.seq_id
|
||||
self._computed_blocks_tracker.remove_seq(seq_id)
|
||||
|
||||
def free_cross(self, seq_group: SequenceGroup) -> None:
|
||||
request_id = seq_group.request_id
|
||||
if request_id not in self.cross_block_tables:
|
||||
# Already freed or hasn't been scheduled yet.
|
||||
return
|
||||
self.cross_block_tables[request_id].free()
|
||||
del self.cross_block_tables[request_id]
|
||||
|
||||
def get_block_table(self, seq: Sequence) -> List[int]:
|
||||
block_ids = self.block_tables[seq.seq_id].physical_block_ids
|
||||
return block_ids # type: ignore
|
||||
|
||||
def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
|
||||
request_id = seq_group.request_id
|
||||
assert request_id in self.cross_block_tables
|
||||
block_ids = self.cross_block_tables[request_id].physical_block_ids
|
||||
assert all(b is not None for b in block_ids)
|
||||
return block_ids # type: ignore
|
||||
|
||||
def access_all_blocks_in_seq(self, seq: Sequence, now: float):
|
||||
if self.enable_caching:
|
||||
# Record the latest access time for the sequence. The actual update
|
||||
# of the block ids is deferred to the sequence free(..) call, since
|
||||
# only during freeing of block ids, the blocks are actually added to
|
||||
# the evictor (which is when the most updated time is required)
|
||||
# (This avoids expensive calls to mark_blocks_as_accessed(..))
|
||||
self._last_access_blocks_tracker.update_last_access(
|
||||
seq.seq_id, now)
|
||||
|
||||
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
|
||||
token_chunk_size: int):
|
||||
# If prefix caching is enabled, mark immutable blocks as computed
|
||||
# right after they have been scheduled (for prefill). This assumes
|
||||
# the scheduler is synchronous so blocks are actually computed when
|
||||
# scheduling the next batch.
|
||||
self.block_allocator.mark_blocks_as_computed([])
|
||||
|
||||
def get_common_computed_block_ids(
|
||||
self, seqs: List[Sequence]) -> GenericSequence[int]:
|
||||
"""Determine which blocks for which we skip prefill.
|
||||
|
||||
With prefix caching we can skip prefill for previously-generated blocks.
|
||||
Currently, the attention implementation only supports skipping cached
|
||||
blocks if they are a contiguous prefix of cached blocks.
|
||||
|
||||
This method determines which blocks can be safely skipped for all
|
||||
sequences in the sequence group.
|
||||
"""
|
||||
computed_seq_block_ids = []
|
||||
for seq in seqs:
|
||||
all_blocks = self.block_tables[seq.seq_id].physical_block_ids
|
||||
num_cached_tokens = (
|
||||
self._computed_blocks_tracker.get_num_cached_tokens(seq))
|
||||
assert num_cached_tokens % self.block_size == 0
|
||||
num_cached_blocks = num_cached_tokens // self.block_size
|
||||
computed_block_ids = all_blocks[:num_cached_blocks]
|
||||
computed_seq_block_ids.append(computed_block_ids)
|
||||
|
||||
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
|
||||
return self.block_allocator.get_common_computed_block_ids(
|
||||
computed_seq_block_ids) # type: ignore
|
||||
|
||||
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
|
||||
if parent_seq.seq_id not in self.block_tables:
|
||||
# Parent sequence has either been freed or never existed.
|
||||
return
|
||||
src_block_table = self.block_tables[parent_seq.seq_id]
|
||||
self.block_tables[child_seq.seq_id] = src_block_table.fork()
|
||||
|
||||
# Track child seq
|
||||
self._last_access_blocks_tracker.add_seq(child_seq.seq_id)
|
||||
|
||||
def can_swap_in(self, seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int) -> AllocStatus:
|
||||
"""Returns the AllocStatus for the given sequence_group
|
||||
with num_lookahead_slots.
|
||||
|
||||
Args:
|
||||
sequence_group (SequenceGroup): The sequence group to swap in.
|
||||
num_lookahead_slots (int): Number of lookahead slots used in
|
||||
speculative decoding, default to 0.
|
||||
|
||||
Returns:
|
||||
AllocStatus: The AllocStatus for the given sequence group.
|
||||
"""
|
||||
return self._can_swap(seq_group, Device.GPU, SequenceStatus.SWAPPED,
|
||||
num_lookahead_slots)
|
||||
|
||||
def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
|
||||
"""Returns the block id mapping (from CPU to GPU) generated by
|
||||
swapping in the given seq_group with num_lookahead_slots.
|
||||
|
||||
Args:
|
||||
seq_group (SequenceGroup): The sequence group to swap in.
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, int]]: The mapping of swapping block from CPU
|
||||
to GPU.
|
||||
"""
|
||||
physical_block_id_mapping = []
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
|
||||
blocks = self.block_tables[seq.seq_id].blocks
|
||||
if len(blocks) == 0:
|
||||
continue
|
||||
|
||||
seq_swap_mapping = self.block_allocator.swap(blocks=blocks,
|
||||
src_device=Device.CPU,
|
||||
dst_device=Device.GPU)
|
||||
|
||||
# Refresh the block ids of the table (post-swap)
|
||||
self.block_tables[seq.seq_id].update(blocks)
|
||||
|
||||
seq_physical_block_id_mapping = {
|
||||
self.block_allocator.get_physical_block_id(
|
||||
Device.CPU, cpu_block_id):
|
||||
self.block_allocator.get_physical_block_id(
|
||||
Device.GPU, gpu_block_id)
|
||||
for cpu_block_id, gpu_block_id in seq_swap_mapping.items()
|
||||
}
|
||||
|
||||
physical_block_id_mapping.extend(
|
||||
list(seq_physical_block_id_mapping.items()))
|
||||
|
||||
return physical_block_id_mapping
|
||||
|
||||
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
|
||||
"""Returns whether we can swap out the given sequence_group
|
||||
with num_lookahead_slots.
|
||||
|
||||
Args:
|
||||
seq_group (SequenceGroup): The sequence group to swap out.
|
||||
num_lookahead_slots (int): Number of lookahead slots used in
|
||||
speculative decoding, default to 0.
|
||||
|
||||
Returns:
|
||||
bool: Whether it's possible to swap out current sequence group.
|
||||
"""
|
||||
alloc_status = self._can_swap(seq_group, Device.CPU,
|
||||
SequenceStatus.RUNNING)
|
||||
return alloc_status == AllocStatus.OK
|
||||
|
||||
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
|
||||
"""Returns the block id mapping (from GPU to CPU) generated by
|
||||
swapping out the given sequence_group with num_lookahead_slots.
|
||||
|
||||
Args:
|
||||
sequence_group (SequenceGroup): The sequence group to swap out.
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, int]]: The mapping of swapping block from
|
||||
GPU to CPU.
|
||||
"""
|
||||
physical_block_id_mapping = []
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
blocks = self.block_tables[seq.seq_id].blocks
|
||||
if len(blocks) == 0:
|
||||
continue
|
||||
|
||||
seq_swap_mapping = self.block_allocator.swap(blocks=blocks,
|
||||
src_device=Device.GPU,
|
||||
dst_device=Device.CPU)
|
||||
|
||||
# Refresh the block ids of the table (post-swap)
|
||||
self.block_tables[seq.seq_id].update(blocks)
|
||||
|
||||
seq_physical_block_id_mapping = {
|
||||
self.block_allocator.get_physical_block_id(
|
||||
Device.GPU, gpu_block_id):
|
||||
self.block_allocator.get_physical_block_id(
|
||||
Device.CPU, cpu_block_id)
|
||||
for gpu_block_id, cpu_block_id in seq_swap_mapping.items()
|
||||
}
|
||||
|
||||
physical_block_id_mapping.extend(
|
||||
list(seq_physical_block_id_mapping.items()))
|
||||
|
||||
return physical_block_id_mapping
|
||||
|
||||
def get_num_free_gpu_blocks(self) -> int:
|
||||
return self.block_allocator.get_num_free_blocks(Device.GPU)
|
||||
|
||||
def get_num_free_cpu_blocks(self) -> int:
|
||||
return self.block_allocator.get_num_free_blocks(Device.CPU)
|
||||
|
||||
def get_prefix_cache_hit_rate(self, device: Device) -> float:
|
||||
return self.block_allocator.get_prefix_cache_hit_rate(device)
|
||||
|
||||
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
|
||||
return self.block_allocator.reset_prefix_cache(device)
|
||||
|
||||
def _can_swap(self,
|
||||
seq_group: SequenceGroup,
|
||||
device: Device,
|
||||
status: SequenceStatus,
|
||||
num_lookahead_slots: int = 0) -> AllocStatus:
|
||||
"""Returns the AllocStatus for swapping in/out the given sequence_group
|
||||
on to the 'device'.
|
||||
|
||||
Args:
|
||||
sequence_group (SequenceGroup): The sequence group to swap in/out.
|
||||
device (Device): device to swap the 'seq_group' on.
|
||||
status (SequenceStatus): The status of sequence which is needed
|
||||
for action. RUNNING for swap out and SWAPPED for swap in
|
||||
num_lookahead_slots (int): Number of lookahead slots used in
|
||||
speculative decoding, default to 0.
|
||||
|
||||
Returns:
|
||||
AllocStatus: The AllocStatus for swapping in/out the given
|
||||
sequence_group on to the 'device'.
|
||||
"""
|
||||
# First determine the number of blocks that will be touched by this
|
||||
# swap. Then verify if there are available blocks in the device
|
||||
# to perform the swap.
|
||||
num_blocks_touched = 0
|
||||
blocks: List[Block] = []
|
||||
for seq in seq_group.get_seqs(status=status):
|
||||
block_table = self.block_tables[seq.seq_id]
|
||||
if block_table.blocks is not None:
|
||||
# Compute the number of blocks to touch for the tokens to be
|
||||
# appended. This does NOT include the full blocks that need
|
||||
# to be touched for the swap.
|
||||
num_blocks_touched += \
|
||||
block_table.get_num_blocks_touched_by_append_slots(
|
||||
block_table.get_unseen_token_ids(seq.get_token_ids()),
|
||||
num_lookahead_slots=num_lookahead_slots)
|
||||
blocks.extend(block_table.blocks)
|
||||
# Compute the number of full blocks to touch and add it to the
|
||||
# existing count of blocks to touch.
|
||||
num_blocks_touched += self.block_allocator.get_num_full_blocks_touched(
|
||||
blocks, device=device)
|
||||
|
||||
watermark_blocks = 0
|
||||
if device == Device.GPU:
|
||||
watermark_blocks = self.watermark_blocks
|
||||
|
||||
if self.block_allocator.get_num_total_blocks(
|
||||
device) < num_blocks_touched:
|
||||
return AllocStatus.NEVER
|
||||
elif self.block_allocator.get_num_free_blocks(
|
||||
device) - num_blocks_touched >= watermark_blocks:
|
||||
return AllocStatus.OK
|
||||
else:
|
||||
return AllocStatus.LATER
|
||||
|
||||
def get_num_cached_tokens(self, seq: Sequence) -> int:
|
||||
"""Get the number of tokens in blocks that are already computed and
|
||||
cached in the block manager for the sequence.
|
||||
"""
|
||||
return self._computed_blocks_tracker.get_num_cached_tokens(seq)
|
||||
0
vllm_vacc/vllm/distributed/__init__.py
Normal file
0
vllm_vacc/vllm/distributed/__init__.py
Normal file
BIN
vllm_vacc/vllm/distributed/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/distributed/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
23
vllm_vacc/vllm/distributed/communication_op.py
Normal file
23
vllm_vacc/vllm/distributed/communication_op.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
def tensor_model_parallel_all_reduce_with_odsp(input_: torch.Tensor) -> torch.Tensor:
|
||||
"""All-reduce the input tensor across model parallel group."""
|
||||
from vllm.distributed import get_tp_group
|
||||
try:
|
||||
total_bytes = input_.numel() * input_.element_size() * get_tp_group().world_size
|
||||
# only support 4M now
|
||||
if total_bytes < 4194304:
|
||||
from torch_vacc.vacc import all_reduce
|
||||
return all_reduce(input_,
|
||||
get_tp_group().rank_in_group,
|
||||
get_tp_group().world_size,
|
||||
get_tp_group().group_id,
|
||||
dev_info = get_tp_group().rank_device_infos)
|
||||
except Exception as e:
|
||||
print("all_reduce by DSP run Fail, now use vccl-ops", e, input_.shape, input_.dtype)
|
||||
|
||||
return get_tp_group().all_reduce(input_)
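# A minimal usage sketch: the wrapper above is a drop-in replacement for
# the stock tensor-parallel all-reduce. The shape is hypothetical; whether
# the DSP fast path is taken depends only on the total payload size
# (< 4 MiB across the TP group) and requires an initialized TP group.
def _example_all_reduce(hidden_states: torch.Tensor) -> torch.Tensor:
    # hidden_states: e.g. [num_tokens, hidden_size] on the current device
    return tensor_model_parallel_all_reduce_with_odsp(hidden_states)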
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
def all_gather_into_tensor(self, input_: torch.Tensor, dim: int = -1, output_tensor: torch.Tensor = None) -> torch.Tensor:
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
input_size = input_.size()
|
||||
# NOTE: we have to use concat-style all-gather here,
|
||||
# stack-style all-gather has compatibility issues with
|
||||
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
|
||||
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
|
||||
# Allocate output tensor.
|
||||
# [N,] => [N*world_size], 1D Tensor
|
||||
if output_tensor is None:
|
||||
output_tensor = torch.empty(output_size,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device)
|
||||
# print("o tensor is:", output_tensor.shape, "i tensor is:", input_.shape, input_size)
|
||||
# All-gather.
|
||||
dist.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
return output_tensor
|
||||
485
vllm_vacc/vllm/distributed/parallel_state.py
Normal file
485
vllm_vacc/vllm/distributed/parallel_state.py
Normal file
@@ -0,0 +1,485 @@
|
||||
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import supports_custom_op
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import (Any, Dict, List, Optional, Tuple,
|
||||
Union)
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.utils import supports_custom_op
|
||||
|
||||
from vllm.distributed.parallel_state import TensorMetadata
|
||||
|
||||
# memory recycler
|
||||
MEMORY_RECYCLER_KEY = ['previous_hidden_states']
|
||||
|
||||
def _split_tensor_dict_concat(
|
||||
tensor_dict: Dict[str, Union[torch.Tensor, Any]]
|
||||
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
|
||||
"""Split the tensor dictionary into two parts:
|
||||
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
|
||||
by its metadata.
|
||||
2. A list of tensors.
|
||||
"""
|
||||
|
||||
# all_tensor_list = ['input_tokens','input_positions', 'slot_mapping','seq_lens_tensor', 'context_lens_tensor','block_tables', 'query_start_loc','seq_start_loc', 'selected_token_indices']
|
||||
metadata_list: List[Tuple[str, Any]] = []
|
||||
tensor_list: List[torch.Tensor] = []
|
||||
all_tensor = []
|
||||
all_tensor_numel = 0
|
||||
for key, value in tensor_dict.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
# Note: we cannot use `value.device` here,
|
||||
# because it contains not only the device type but also the device
|
||||
# index (e.g. "cuda:0"). We only need the device type.
|
||||
# receiving side will set the device index.
|
||||
device = value.device.type
|
||||
|
||||
if not value.is_cpu and value.numel() > 0:
|
||||
value_bytes_tensor = value.view(torch.int8)
|
||||
all_tensor.append(value_bytes_tensor.view([-1]))
|
||||
all_tensor_numel += value_bytes_tensor.numel()
|
||||
# tensor_list.append(value)
|
||||
|
||||
metadata_list.append(
|
||||
(key, TensorMetadata(device, value.dtype, value.size())))
|
||||
|
||||
else:
|
||||
metadata_list.append((key, value))
|
||||
if len(all_tensor) != 0:
|
||||
|
||||
memory_recycler_dynamic_output = None
|
||||
# Compute the total size of all_tensor
|
||||
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler, DeepseekMTPMemoryRecycler
|
||||
if isinstance(memory_recycler, DeepseekMTPMemoryRecycler):
|
||||
memory_recycler_dynamic_output = memory_recycler.DYNAMIC_OUTPUT_BUFFER.view(torch.int8)[:all_tensor_numel]
|
||||
|
||||
if memory_recycler_dynamic_output is not None:
|
||||
all_tensor = torch.concatenate(all_tensor, 0, out = memory_recycler_dynamic_output)
|
||||
else:
|
||||
all_tensor = torch.concatenate(all_tensor, 0)
|
||||
|
||||
tensor_list.append(all_tensor)
|
||||
metadata_list.append(("all_tensor", TensorMetadata(all_tensor.device.type, all_tensor.dtype, all_tensor.size())))
|
||||
|
||||
return metadata_list, tensor_list
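# A minimal receive-side sketch for the concat layout produced above: the
# sender broadcasts one flattened int8 buffer plus per-key metadata, and
# the receiver slices the buffer back into typed views. This assumes every
# metadata entry other than "all_tensor" was packed into the buffer; the
# real code also handles CPU and empty tensors separately.
def _example_unpack_all_tensor(
        all_tensor: torch.Tensor,
        metadata_list: List[Tuple[str, Any]]) -> Dict[str, torch.Tensor]:
    out: Dict[str, torch.Tensor] = {}
    offset = 0
    for key, meta in metadata_list:
        if key == "all_tensor" or not isinstance(meta, TensorMetadata):
            continue
        nbytes = meta.size.numel() * torch.empty((), dtype=meta.dtype).element_size()
        # Reinterpret the int8 slice as the original dtype and shape.
        out[key] = all_tensor[offset:offset + nbytes].view(meta.dtype).view(meta.size)
        offset += nbytes
    return out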
|
||||
|
||||
def all_gather_to_rank0(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
world_size = self.world_size
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
|
||||
# For TPUs, use TPU communicator.
|
||||
tpu_comm = self.tpu_communicator
|
||||
if tpu_comm is not None and not tpu_comm.disabled:
|
||||
return tpu_comm.all_gather(input_, dim)
|
||||
|
||||
# For HPUs, use HPU communicator.
|
||||
hpu_comm = self.hpu_communicator
|
||||
if hpu_comm is not None and not hpu_comm.disabled:
|
||||
return hpu_comm.all_gather(input_, dim)
|
||||
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
input_size = input_.size()
|
||||
# NOTE: we have to use concat-style all-gather here,
|
||||
# stack-style all-gather has compatibility issues with
|
||||
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
|
||||
output_size = (input_size[0] * world_size, ) + input_size[1:]
|
||||
|
||||
try:
|
||||
total_bytes = input_.numel() * input_.element_size() * world_size
|
||||
# only support 4M now
|
||||
if total_bytes < 4194304:
|
||||
from torch_vacc.vacc.custom_ops import all_gather
|
||||
output_tensor = all_gather(input_, self.rank_in_group, self.world_size, self.group_id,
|
||||
dev_info = self.rank_device_infos)
|
||||
|
||||
if self.rank_in_group != 0:
|
||||
output_tensor = None
|
||||
else:
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(input_size[:dim] +
|
||||
(world_size *
|
||||
input_size[dim], ) +
|
||||
input_size[dim + 1:])
|
||||
|
||||
return output_tensor
|
||||
except Exception as e:
|
||||
print("all_gather by DSP run Fail, now use vccl-ops", e, input_.shape, input_.dtype)
|
||||
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(output_size,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device)
|
||||
# All-gather.
|
||||
torch.distributed.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
if self.rank_in_group != 0:
|
||||
output_tensor = None
|
||||
else:
|
||||
# Reshape
|
||||
output_tensor = output_tensor.reshape((world_size, ) + input_size)
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(input_size[:dim] +
|
||||
(world_size *
|
||||
input_size[dim], ) +
|
||||
input_size[dim + 1:])
|
||||
return output_tensor
|
||||
|
||||
def generate_group_id(self, group_id):
|
||||
self.group_id = group_id
|
||||
|
||||
def generate_rank_device_infos(self):
|
||||
import numpy as np
|
||||
import os
|
||||
# encoder rank_dev_list
|
||||
def combine_arrays(a, b):
|
||||
a = np.asarray(a, dtype=np.uint32)
|
||||
b = np.asarray(b, dtype=np.uint32)
|
||||
|
||||
if len(a) != len(b):
|
||||
raise ValueError("两个数组的长度必须一致。")
|
||||
|
||||
a_shifted = np.left_shift(a, 16)
|
||||
combined = np.bitwise_or(a_shifted, b)
|
||||
return combined.tolist()
|
||||
|
||||
# decoder rank_dev_list
|
||||
def uncombine_array(array):
|
||||
array = np.asarray(array, dtype=np.uint32)
|
||||
o_0 = array >> 16
|
||||
o_1 = array << 16 >> 16
|
||||
return o_0, o_1
|
||||
|
||||
physical_devices = self.ranks
|
||||
visible_devices = os.getenv('VACC_VISIBLE_DEVICES')
|
||||
|
||||
if visible_devices is not None:
|
||||
device_list = visible_devices.split(',')
|
||||
device_count = len(device_list)
|
||||
assert device_count >= len(self.ranks), f'VACC_VISIBLE_DEVICES:{device_count} is less than ranks:{len(self.ranks)}, please designate more devices'
|
||||
physical_devices = [int(device_list[i]) for i in self.ranks]
|
||||
# print("[vccl] logic_devices:physical_devices ", self.ranks, physical_devices)
|
||||
|
||||
logic_ranks = [self.ranks.index(rank) for rank in self.ranks]
|
||||
self.rank_device_infos = combine_arrays(logic_ranks, physical_devices)
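# A minimal round-trip sketch of the 16-bit packing above: the logical
# rank goes into the high half-word and the physical device id into the
# low half-word. The ids are hypothetical.
def _example_pack_unpack_rank_device() -> None:
    import numpy as np
    logic_ranks = [0, 1, 2, 3]
    physical_devices = [4, 5, 6, 7]
    packed = np.bitwise_or(
        np.left_shift(np.asarray(logic_ranks, dtype=np.uint32), 16),
        np.asarray(physical_devices, dtype=np.uint32))
    assert (packed >> 16).tolist() == logic_ranks
    assert (packed & 0xFFFF).tolist() == physical_devices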
|
||||
|
||||
def get_bitwidth(dtype):
|
||||
if dtype.is_floating_point:
|
||||
return torch.finfo(dtype).bits
|
||||
else:
|
||||
return torch.iinfo(dtype).bits
|
||||
|
||||
class GroupCoordinator:
|
||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
User-facing all-reduce function before we actually call the
|
||||
all-reduce operation.
|
||||
|
||||
We need this because Dynamo does not support passing an arbitrary
|
||||
object (`self` in this case) to a custom op. We need to pass the
|
||||
group name as a string, and then look up the group coordinator from
|
||||
the group name, dispatch the all-reduce operation to the group
|
||||
coordinator.
|
||||
|
||||
In addition, PyTorch custom ops do not support mutation or returning
|
||||
a new tensor in the same op. So we need to figure out if the op is
|
||||
in-place or out-of-place ahead of time.
|
||||
"""
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if self.world_size == 1:
|
||||
return input_
|
||||
if input_.is_cpu:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
ipex.distributed.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
# vacc impl
|
||||
# s0, s1 = input_.shape
|
||||
# output_tensor = torch.empty([32, s0, s1],
|
||||
# dtype=input_.dtype,
|
||||
# device=input_.device)
|
||||
# torch.distributed.all_gather_into_tensor(output_tensor,
|
||||
# input_,
|
||||
# group=self.device_group)
|
||||
# input_ = output_tensor.sum(dim=0)
|
||||
torch.distributed.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
def broadcast_tensor_dict(
|
||||
self,
|
||||
tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
|
||||
src: int = 0,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
metadata_group: Optional[ProcessGroup] = None
|
||||
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
|
||||
"""Broadcast the input tensor dictionary.
|
||||
NOTE: `src` is the local rank of the source rank.
|
||||
"""
|
||||
|
||||
# all_tensor_list = ['input_tokens','input_positions', 'slot_mapping','seq_lens_tensor', 'context_lens_tensor','block_tables', 'query_start_loc','seq_start_loc', 'selected_token_indices']
|
||||
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if (not torch.distributed.is_initialized() or self.world_size == 1):
|
||||
return tensor_dict
|
||||
|
||||
group = self.device_group
|
||||
metadata_group = self.cpu_group
|
||||
assert src < self.world_size, f"Invalid src rank ({src})"
|
||||
|
||||
rank_in_group = self.rank_in_group
|
||||
|
||||
if rank_in_group == src:
|
||||
metadata_list: List[Tuple[Any, Any]] = []
|
||||
assert isinstance(
|
||||
tensor_dict,
|
||||
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
|
||||
# metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
|
||||
metadata_list, tensor_list = _split_tensor_dict_concat(tensor_dict)
|
||||
# `metadata_list` lives in CPU memory.
|
||||
# `broadcast_object_list` has serialization & deserialization,
|
||||
# all happening on CPU. Therefore, we can use the CPU group.
|
||||
# metadata_list contains (key, value) pairs; the value is TensorMetadata holding only the shape, no data
|
||||
self.broadcast_object(metadata_list, src=src)
|
||||
async_handles = []
|
||||
for tensor in tensor_list:
|
||||
if tensor.numel() == 0:
|
||||
# Skip broadcasting empty tensors.
|
||||
continue
|
||||
if tensor.is_cpu:
|
||||
# use metadata_group for CPU tensors
|
||||
handle = torch.distributed.broadcast(tensor,
|
||||
src=self.ranks[src],
|
||||
group=metadata_group,
|
||||
async_op=True)
|
||||
async_handles.append(handle)
|
||||
else:
|
||||
# use group for GPU tensors
|
||||
total_bytes = tensor.numel() * tensor.element_size()
|
||||
use_dist = True
|
||||
# only payloads smaller than 4 MiB are supported for now
|
||||
if total_bytes < 4194304:
|
||||
try:
|
||||
from torch_vacc.vacc.custom_ops import broadcast
|
||||
#print("send tensor is:", tensor.shape, tensor.dtype, self.rank)
|
||||
broadcast(tensor, self.rank_in_group, self.world_size, root_rank=0, group_id=self.group_id,
|
||||
dev_info = self.rank_device_infos)
|
||||
use_dist = False
|
||||
except Exception as e:
|
||||
print("odsp broadcast run fail, now using distributed:", e)
|
||||
|
||||
if use_dist:
|
||||
handle = torch.distributed.broadcast(tensor,
|
||||
src=self.ranks[src],
|
||||
group=group,
|
||||
async_op=True)
|
||||
async_handles.append(handle)
|
||||
|
||||
for async_handle in async_handles:
|
||||
async_handle.wait()
|
||||
|
||||
else:
|
||||
# other rank
|
||||
metadata_list = self.broadcast_object(None, src=src)
|
||||
tensor_dict = {}
|
||||
async_handles = []
|
||||
tensor_size = []  # list of [key, shape] pairs used to split all_tensor
|
||||
dataType_list = []
|
||||
for key, value in metadata_list:
|
||||
# if rank_in_group == 1:
|
||||
# print('rank1 k v ', key, value)
|
||||
if isinstance(value, TensorMetadata):
|
||||
tensor = None
|
||||
|
||||
# dtype is fixed to int8
|
||||
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler, DeepseekMTPMemoryRecycler
|
||||
if isinstance(memory_recycler, DeepseekMTPMemoryRecycler) and value.dtype == torch.int8:
|
||||
tensor = memory_recycler.DYNAMIC_OUTPUT_BUFFER.view(value.dtype)[:value.size.numel()].view(value.size)
|
||||
|
||||
if tensor is None:
|
||||
tensor = torch.empty(value.size,
|
||||
dtype=value.dtype,
|
||||
device=value.device)
|
||||
if tensor.numel() == 0:
|
||||
# Skip broadcasting empty tensors.
|
||||
tensor_dict[key] = tensor
|
||||
continue
|
||||
if tensor.is_cpu:
|
||||
# use metadata_group for CPU tensors
|
||||
handle = torch.distributed.broadcast(
|
||||
tensor,
|
||||
src=self.ranks[src],
|
||||
group=metadata_group,
|
||||
async_op=True)
|
||||
async_handles.append(handle)
|
||||
tensor_dict[key] = tensor
|
||||
else:
|
||||
# use group for GPU tensors
|
||||
if key == "all_tensor":
|
||||
total_bytes = tensor.numel() * tensor.element_size()
|
||||
use_dist = True
|
||||
# only payloads smaller than 4 MiB are supported for now
|
||||
if total_bytes < 4194304:
|
||||
try:
|
||||
from torch_vacc.vacc.custom_ops import broadcast
|
||||
tensor = broadcast(tensor, self.rank_in_group, self.world_size, root_rank=0, group_id=self.group_id,
|
||||
dev_info = self.rank_device_infos)
|
||||
use_dist = False
|
||||
except Exception as e:
|
||||
print("dsp brocast run fail, now using distributed:", e)
|
||||
|
||||
if use_dist:
|
||||
handle = torch.distributed.broadcast(
|
||||
tensor,  # the concatenated tensor
|
||||
src=self.ranks[src],
|
||||
group=group,
|
||||
async_op=False)
|
||||
|
||||
# split all_tensor back into per-key tensors using the recorded (key, shape) pairs and store them in tensor_dict
|
||||
start = 0
|
||||
idx = 0
|
||||
for ki_vi in tensor_size:
|
||||
ki, vi = ki_vi
|
||||
length = vi.numel() * int(get_bitwidth(dataType_list[idx]) / 8)
|
||||
if ki in MEMORY_RECYCLER_KEY:
|
||||
tensor_dict[ki] = tensor[start:start+length].view(dataType_list[idx]).view(vi)
|
||||
else:
|
||||
value_tensor = torch.empty(vi,
|
||||
dtype=dataType_list[idx],
|
||||
device=value.device)
|
||||
recv_tensor = tensor[start:start+length].view(dataType_list[idx]).view(vi)
|
||||
tensor_dict[ki] = value_tensor.copy_(recv_tensor)
|
||||
start += length
|
||||
idx += 1
|
||||
else:
|
||||
dataType_list.append(value.dtype)
|
||||
tensor_size.append([key, value.size])
|
||||
else:
|
||||
tensor_dict[key] = value
|
||||
for async_handle in async_handles:
|
||||
async_handle.wait()
|
||||
return tensor_dict
|
||||
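# Hedged sketch (not original code): how the receiver side above slices the
# concatenated "all_tensor" back into per-key tensors. `flat` is assumed to be
# the 1-D int8 buffer that was broadcast, and `tensor_size` / `dataType_list`
# are the (key, shape) and dtype lists collected from the metadata.
@staticmethod
def _split_concat_buffer(flat, tensor_size, dataType_list):
    out, start = {}, 0
    for idx, (key, shape) in enumerate(tensor_size):
        dtype = dataType_list[idx]
        length = shape.numel() * get_bitwidth(dtype) // 8
        out[key] = flat[start:start + length].view(dtype).view(shape)
        start += length
    return out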
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1, output_: torch.Tensor = None) -> torch.Tensor:
|
||||
world_size = self.world_size
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
|
||||
if self.use_custom_op_call:
|
||||
return torch.ops.vllm.all_gather(input_,
|
||||
dim,
|
||||
world_size,
|
||||
group_name=self.unique_name)
|
||||
else:
|
||||
# use the output-buffer-reusing variant of all_gather
|
||||
if output_ is not None:
|
||||
return self.device_communicator.all_gather_into_tensor(input_, dim, output_)
|
||||
return self._all_gather_out_place(input_, dim)
|
||||
|
||||
def recv_tensor_dict(
|
||||
self,
|
||||
src: Optional[int] = None,
|
||||
all_gather_group: Optional["GroupCoordinator"] = None,
|
||||
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
|
||||
"""Recv the input tensor dictionary.
|
||||
NOTE: `src` is the local rank of the source rank.
|
||||
"""
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if not torch.distributed.is_initialized() or self.world_size == 1:
|
||||
return None
|
||||
all_gather_size = (1 if all_gather_group is None else
|
||||
all_gather_group.world_size)
|
||||
all_gather_rank = (0 if all_gather_group is None else
|
||||
all_gather_group.rank_in_group)
|
||||
|
||||
group = self.device_group
|
||||
metadata_group = self.cpu_group
|
||||
|
||||
if src is None:
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
assert src < self.world_size, f"Invalid src rank ({src})"
|
||||
|
||||
recv_metadata_list = self.recv_object(src=src)
|
||||
tensor_dict: dict[str, Any] = {}
|
||||
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import alloc_pipeline_parallel_recycler_buffer
|
||||
memory_recycler_list = ["hidden_states", "residual"]
|
||||
for key, value in recv_metadata_list:
|
||||
if isinstance(value, TensorMetadata):
|
||||
# decide whether memory recycling should be used:
|
||||
# 1. the key is in [hidden_states, residual], i.e. the tensor comes from pipeline parallelism
|
||||
# 2. a tensor for this key can be allocated from the memory_recycling module
|
||||
use_create_recycler_tensor = False
|
||||
tensor = None
|
||||
value_tensor = None  # buffer that receives the full all_gather result
|
||||
if key in memory_recycler_list:
|
||||
tensor = alloc_pipeline_parallel_recycler_buffer(value.size, value.dtype, key)
|
||||
if tensor is not None:
|
||||
use_create_recycler_tensor = True
|
||||
|
||||
if not use_create_recycler_tensor:
|
||||
tensor = torch.empty(value.size,
|
||||
dtype=value.dtype,
|
||||
device=value.device)
|
||||
|
||||
value_tensor = tensor
|
||||
|
||||
if tensor.numel() == 0:
|
||||
# Skip broadcasting empty tensors.
|
||||
tensor_dict[key] = tensor
|
||||
continue
|
||||
|
||||
# send-allgather: send only a slice, then do allgather.
|
||||
use_all_gather = (all_gather_group is not None
|
||||
and tensor.numel() % all_gather_size == 0)
|
||||
|
||||
if use_all_gather:
|
||||
orig_shape = tensor.shape
|
||||
# with memory recycling the buffer is already laid out correctly, so a view is enough (no reshape)
|
||||
if use_create_recycler_tensor:
|
||||
tensor = tensor.view(all_gather_size,
|
||||
-1)[all_gather_rank].contiguous()
|
||||
else:
|
||||
tensor = tensor.reshape(all_gather_size,
|
||||
-1)[all_gather_rank]
|
||||
|
||||
if tensor.is_cpu:
|
||||
# use metadata_group for CPU tensors
|
||||
torch.distributed.recv(tensor,
|
||||
src=self.ranks[src],
|
||||
group=metadata_group)
|
||||
else:
|
||||
# use group for GPU tensors
|
||||
torch.distributed.recv(tensor,
|
||||
src=self.ranks[src],
|
||||
group=group)
|
||||
if use_all_gather:
|
||||
# do the allgather
|
||||
if use_create_recycler_tensor:
|
||||
tensor = all_gather_group.all_gather( # type: ignore
|
||||
tensor, dim=0, output_ = value_tensor)
|
||||
tensor = tensor.view(orig_shape)
|
||||
else:
|
||||
tensor = all_gather_group.all_gather( # type: ignore
|
||||
tensor, dim=0)
|
||||
tensor = tensor.reshape(orig_shape)
|
||||
|
||||
tensor_dict[key] = tensor
|
||||
else:
|
||||
tensor_dict[key] = value
|
||||
return tensor_dict
|
||||
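# Hedged sketch (an assumption, not original code) of the "send-allgather"
# trick used in recv_tensor_dict: each rank receives only a 1/N slice of a
# tensor and the all-gather group reassembles the full tensor afterwards.
def _recv_slice_then_all_gather(tensor, all_gather_group, all_gather_rank,
                                all_gather_size, src_global_rank, device_group):
    orig_shape = tensor.shape
    part = tensor.reshape(all_gather_size, -1)[all_gather_rank].contiguous()
    torch.distributed.recv(part, src=src_global_rank, group=device_group)
    full = all_gather_group.all_gather(part, dim=0)
    return full.reshape(orig_shape)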
0
vllm_vacc/vllm/engine/__init__.py
Normal file
0
vllm_vacc/vllm/engine/__init__.py
Normal file
BIN
vllm_vacc/vllm/engine/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/engine/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/engine/__pycache__/arg_utils.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/engine/__pycache__/arg_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/engine/__pycache__/llm_engine.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/engine/__pycache__/llm_engine.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/engine/__pycache__/metrics.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/engine/__pycache__/metrics.cpython-312.pyc
Normal file
Binary file not shown.
158
vllm_vacc/vllm/engine/arg_utils.py
Normal file
158
vllm_vacc/vllm/engine/arg_utils.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import argparse
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
import json
|
||||
import sys
|
||||
import threading
|
||||
import warnings
|
||||
from dataclasses import MISSING, dataclass, fields, is_dataclass
|
||||
from itertools import permutations
|
||||
from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
|
||||
Type, TypeVar, Union, cast, get_args, get_origin)
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from typing_extensions import TypeIs, deprecated
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.plugins import load_general_plugins
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
|
||||
from vllm.transformers_utils.utils import check_gguf_file
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import (FlexibleArgumentParser, GiB_bytes, get_ip,
|
||||
is_in_ray_actor)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
def _set_default_args(self, usage_context: UsageContext,
|
||||
model_config: ModelConfig) -> None:
|
||||
"""Set Default Arguments for V1 Engine."""
|
||||
|
||||
# Upstream V1 defaults chunked prefill and prefix caching to enabled for
# non-pooling tasks; this override disables both here.
|
||||
self.enable_chunked_prefill = False
|
||||
self.enable_prefix_caching = False
|
||||
if model_config.runner_type != "pooling":
|
||||
# TODO: When prefix caching supports prompt embeds inputs, this
|
||||
# check can be removed.
|
||||
if (self.enable_prompt_embeds
|
||||
and self.enable_prefix_caching is not False):
|
||||
logger.warning(
|
||||
"--enable-prompt-embeds and --enable-prefix-caching "
|
||||
"are not supported together in V1. Prefix caching has "
|
||||
"been disabled.")
|
||||
|
||||
# V1 should use the new scheduler by default.
|
||||
# Swap it only if this arg is set to the original V0 default
|
||||
if self.scheduler_cls == EngineArgs.scheduler_cls:
|
||||
self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"
|
||||
|
||||
# When no user override, set the default values based on the usage
|
||||
# context.
|
||||
# Use different default values for different hardware.
|
||||
|
||||
# Try to query the device name on the current platform. If it fails,
|
||||
# it may be because the platform that imports vLLM is not the same
|
||||
# as the platform that vLLM is running on (e.g. the case of scaling
|
||||
# vLLM with Ray) and has no GPUs. In this case we use the default
|
||||
# values for non-H100/H200 GPUs.
|
||||
try:
|
||||
device_memory = current_platform.get_device_total_memory()
|
||||
device_name = current_platform.get_device_name().lower()
|
||||
except Exception:
|
||||
# This is only used to set default_max_num_batched_tokens
|
||||
device_memory = 0
device_name = ""  # keep the device-name check below well-defined
|
||||
|
||||
# NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
|
||||
# throughput, see PR #17885 for more details.
|
||||
# So here we do an extra device name check to prevent such regression.
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
|
||||
# For GPUs like H100 and MI300x, use larger default values.
|
||||
default_max_num_batched_tokens = {
|
||||
UsageContext.LLM_CLASS: 16384,
|
||||
UsageContext.OPENAI_API_SERVER: 8192,
|
||||
}
|
||||
default_max_num_seqs = {
|
||||
UsageContext.LLM_CLASS: 1024,
|
||||
UsageContext.OPENAI_API_SERVER: 1024,
|
||||
}
|
||||
else:
|
||||
# TODO(woosuk): Tune the default values for other hardware.
|
||||
default_max_num_batched_tokens = {
|
||||
UsageContext.LLM_CLASS: 8192,
|
||||
UsageContext.OPENAI_API_SERVER: 2048,
|
||||
}
|
||||
default_max_num_seqs = {
|
||||
UsageContext.LLM_CLASS: 4,
|
||||
UsageContext.OPENAI_API_SERVER: 4,
|
||||
}
|
||||
|
||||
# tpu specific default values.
|
||||
if current_platform.is_tpu():
|
||||
default_max_num_batched_tokens_tpu = {
|
||||
UsageContext.LLM_CLASS: {
|
||||
'V6E': 2048,
|
||||
'V5E': 1024,
|
||||
'V5P': 512,
|
||||
},
|
||||
UsageContext.OPENAI_API_SERVER: {
|
||||
'V6E': 1024,
|
||||
'V5E': 512,
|
||||
'V5P': 256,
|
||||
}
|
||||
}
|
||||
|
||||
# cpu specific default values.
|
||||
if current_platform.is_cpu():
|
||||
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
|
||||
default_max_num_batched_tokens = {
|
||||
UsageContext.LLM_CLASS: 4096 * world_size,
|
||||
UsageContext.OPENAI_API_SERVER: 2048 * world_size,
|
||||
}
|
||||
default_max_num_seqs = {
|
||||
UsageContext.LLM_CLASS: 256 * world_size,
|
||||
UsageContext.OPENAI_API_SERVER: 128 * world_size,
|
||||
}
|
||||
|
||||
use_context_value = usage_context.value if usage_context else None
|
||||
if (self.max_num_batched_tokens is None
|
||||
and usage_context in default_max_num_batched_tokens):
|
||||
if current_platform.is_tpu():
|
||||
chip_name = current_platform.get_device_name()
|
||||
if chip_name in default_max_num_batched_tokens_tpu[
|
||||
usage_context]:
|
||||
self.max_num_batched_tokens = \
|
||||
default_max_num_batched_tokens_tpu[
|
||||
usage_context][chip_name]
|
||||
else:
|
||||
self.max_num_batched_tokens = \
|
||||
default_max_num_batched_tokens[usage_context]
|
||||
else:
|
||||
if not self.enable_chunked_prefill:
|
||||
self.max_num_batched_tokens = model_config.max_model_len
|
||||
else:
|
||||
self.max_num_batched_tokens = \
|
||||
default_max_num_batched_tokens[usage_context]
|
||||
logger.debug(
|
||||
"Setting max_num_batched_tokens to %d for %s usage context.",
|
||||
self.max_num_batched_tokens, use_context_value)
|
||||
|
||||
if (self.max_num_seqs is None
|
||||
and usage_context in default_max_num_seqs):
|
||||
self.max_num_seqs = min(default_max_num_seqs[usage_context],
|
||||
self.max_num_batched_tokens or sys.maxsize)
|
||||
|
||||
logger.debug("Setting max_num_seqs to %d for %s usage context.",
|
||||
self.max_num_seqs, use_context_value)
|
||||
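# Illustrative sketch (an assumption, not original code): how the defaults
# above resolve for a CPU deployment serving the OpenAI API with
# tensor_parallel_size=2 and pipeline_parallel_size=1.
def _example_cpu_defaults():
    world_size = 1 * 2
    max_num_batched_tokens = 2048 * world_size           # OPENAI_API_SERVER entry
    max_num_seqs = min(128 * world_size, max_num_batched_tokens)
    return max_num_batched_tokens, max_num_seqs          # -> (4096, 256)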
|
||||
49
vllm_vacc/vllm/engine/llm_engine.py
Normal file
49
vllm_vacc/vllm/engine/llm_engine.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from typing import Dict, Optional
|
||||
from vllm.engine.metrics_types import StatLoggerBase
|
||||
|
||||
class LLMEngine:
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(
|
||||
cls,
|
||||
engine_args: EngineArgs,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
||||
) -> "LLMEngine":
|
||||
"""Creates an LLM engine from the engine arguments."""
|
||||
# Create the engine configs.
|
||||
vllm_config = engine_args.create_engine_config(usage_context)
|
||||
#patch to prevent num_speculative_tokens > 1
|
||||
speculative_mode = hasattr(vllm_config, 'speculative_config')
|
||||
if speculative_mode and \
|
||||
hasattr(vllm_config.speculative_config, 'num_speculative_tokens') and \
|
||||
vllm_config.speculative_config.num_speculative_tokens != 1:
|
||||
raise ValueError(f'from_engine_args: only num_speculative_tokens == 1 is supported, but got {vllm_config.speculative_config.num_speculative_tokens}')
|
||||
|
||||
default_model_infos = "default"
|
||||
if speculative_mode:
|
||||
if hasattr(vllm_config.speculative_config, 'method'):
|
||||
default_model_infos = vllm_config.speculative_config.method
|
||||
|
||||
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
|
||||
vllm_vacc_config_manager().update_model_infos(default_model_infos)
|
||||
|
||||
import vllm.envs as envs
|
||||
engine_cls = None
|
||||
if envs.VLLM_USE_V1:
|
||||
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
|
||||
engine_cls = V1LLMEngine
|
||||
else:
|
||||
from vllm.engine.llm_engine import LLMEngine as DefaultEngine
|
||||
engine_cls = DefaultEngine
|
||||
|
||||
assert engine_cls is not None, f"LLMEngine is empty: {engine_cls}"
|
||||
|
||||
return engine_cls.from_vllm_config(
|
||||
vllm_config=vllm_config,
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
disable_log_stats=engine_args.disable_log_stats,
|
||||
)
|
||||
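# Minimal usage sketch (the model path is a placeholder, not part of the
# original file): building an engine through the patched from_engine_args,
# which routes to the V1 or default engine depending on VLLM_USE_V1.
def _example_build_engine():
    from vllm.engine.arg_utils import EngineArgs
    args = EngineArgs(model="/path/to/model", tensor_parallel_size=1)
    return LLMEngine.from_engine_args(args)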
69
vllm_vacc/vllm/engine/metrics.py
Normal file
69
vllm_vacc/vllm/engine/metrics.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from vllm.engine.metrics_types import (StatLoggerBase, Stats)
|
||||
import vllm_vacc.vllm.model_executor.models.vars as global_vars
|
||||
|
||||
class LoggingStatLogger(StatLoggerBase):
|
||||
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
|
||||
|
||||
def log(self, stats: Stats) -> None:
|
||||
from vllm.engine.metrics import local_interval_elapsed, get_throughput, logger
|
||||
"""Called by LLMEngine.
|
||||
Logs to Stdout every self.local_interval seconds."""
|
||||
|
||||
# Save tracked stats for token counters.
|
||||
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
|
||||
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
|
||||
|
||||
# Update spec decode metrics
|
||||
self.maybe_update_spec_decode_metrics(stats)
|
||||
|
||||
# Log locally every local_interval seconds.
|
||||
if local_interval_elapsed(stats.now, self.last_local_log,
|
||||
self.local_interval):
|
||||
# Compute summary metrics for tracked stats (and log them
|
||||
# to prometheus if applicable).
|
||||
prompt_throughput = get_throughput(self.num_prompt_tokens,
|
||||
now=stats.now,
|
||||
last_log=self.last_local_log)
|
||||
generation_throughput = get_throughput(
|
||||
self.num_generation_tokens,
|
||||
now=stats.now,
|
||||
last_log=self.last_local_log)
|
||||
|
||||
log_fn = logger.info
|
||||
if not any((prompt_throughput, generation_throughput,
|
||||
self.last_prompt_throughput,
|
||||
self.last_generation_throughput)):
|
||||
# Avoid log noise on an idle production system
|
||||
log_fn = logger.debug
|
||||
|
||||
log_fn(
|
||||
"Avg prompt throughput: %.1f tokens/s, "
|
||||
"Avg generation throughput: %.1f tokens/s, "
|
||||
"Running: %d reqs, Swapped: %d reqs, "
|
||||
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
|
||||
"CPU KV cache usage: %.1f%%., "
|
||||
"Do sequences length: %s",
|
||||
prompt_throughput,
|
||||
generation_throughput,
|
||||
stats.num_running_sys,
|
||||
stats.num_swapped_sys,
|
||||
stats.num_waiting_sys,
|
||||
stats.gpu_cache_usage_sys * 100,
|
||||
stats.cpu_cache_usage_sys * 100,
|
||||
str(global_vars.DO_SEQ_LENS)
|
||||
)
|
||||
if (stats.cpu_prefix_cache_hit_rate >= 0
|
||||
or stats.gpu_prefix_cache_hit_rate >= 0):
|
||||
log_fn(
|
||||
"Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%",
|
||||
stats.gpu_prefix_cache_hit_rate * 100,
|
||||
stats.cpu_prefix_cache_hit_rate * 100,
|
||||
)
|
||||
if self.spec_decode_metrics is not None:
|
||||
logger.debug(
|
||||
self._format_spec_decode_metrics_str(
|
||||
self.spec_decode_metrics))
|
||||
|
||||
self._reset(stats, prompt_throughput, generation_throughput)
|
||||
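# Hedged sketch (an assumption, not original code): get_throughput is treated
# here as "tokens accumulated since the last log divided by elapsed seconds",
# which is how the prompt/generation throughputs above are read.
def _example_throughput(tracked_token_counts, now, last_log):
    elapsed = now - last_log
    return float(sum(tracked_token_counts)) / elapsed if elapsed > 0 else 0.0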
|
||||
|
||||
0
vllm_vacc/vllm/engine/multiprocessing/__init__.py
Normal file
0
vllm_vacc/vllm/engine/multiprocessing/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
103
vllm_vacc/vllm/engine/multiprocessing/engine.py
Normal file
103
vllm_vacc/vllm/engine/multiprocessing/engine.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR,
|
||||
RPCError,
|
||||
RPCProcessRequest,
|
||||
RPCAbortRequest)
|
||||
from vllm.config import VllmConfig
|
||||
import signal
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.config import (
|
||||
maybe_register_config_serialize_by_value)
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MQLLMEngine:
|
||||
|
||||
def _handle_process_request(self, request: RPCProcessRequest):
|
||||
"""Handle RPCProcessRequest by adding it to the LLMEngine."""
|
||||
request_id = request.request_id
|
||||
|
||||
if self._errored_with is not None:
|
||||
rpc_err = RPCError(request_id=request_id,
|
||||
is_engine_errored=True,
|
||||
exception=ENGINE_DEAD_ERROR(self._errored_with))
|
||||
self._send_outputs(rpc_err)
|
||||
|
||||
try:
|
||||
self.engine.add_request(
|
||||
request_id=request_id,
|
||||
prompt=request.prompt,
|
||||
params=request.params,
|
||||
lora_request=request.lora_request,
|
||||
trace_headers=request.trace_headers,
|
||||
prompt_adapter_request=request.prompt_adapter_request,
|
||||
priority=request.priority)
|
||||
|
||||
if self.log_requests:
|
||||
from vllm.engine.multiprocessing.engine import logger
|
||||
|
||||
if request.prompt.get('prompt_token_ids') is not None:
|
||||
# logger.info("Added request: %s, %s, prompt length: %s", request.request_id, request.prompt['prompt_token_ids'], len(request.prompt['prompt_token_ids']))
|
||||
logger.info("Added request: %s, prompt length: %s", request.request_id, len(request.prompt['prompt_token_ids']))
|
||||
else:
|
||||
logger.info("Added request %s.", request.request_id)
|
||||
|
||||
except Exception as e:
|
||||
# We do not set self._errored = True here, since the error
|
||||
# is due to an issue adding this request to the engine,
|
||||
# rather than an issue with the engine itself.
|
||||
is_errored = self._errored_with is not None
|
||||
rpc_err = RPCError(request_id=request_id,
|
||||
is_engine_errored=is_errored,
|
||||
exception=e)
|
||||
self._send_outputs(rpc_err)
|
||||
|
||||
# Remove request from the engine.
|
||||
self.engine.abort_request(request_id)
|
||||
|
||||
def _handle_abort_request(self, request: RPCAbortRequest):
|
||||
self.engine.abort_request(request.request_id)
|
||||
if self.log_requests:
|
||||
from vllm.engine.multiprocessing.engine import logger
|
||||
import vllm_vacc.vllm.model_executor.models.vars as global_vars
|
||||
logger.info("Aborted request: %s, prompt length: %s", request.request_id, global_vars.DO_SEQ_LENS)
|
||||
|
||||
def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext,
|
||||
ipc_path: str, disable_log_stats: bool,
|
||||
disable_log_requests: bool, engine_alive):
|
||||
|
||||
#patch to prevent num_speculative_tokens > 1
|
||||
speculative_mode = hasattr(vllm_config, 'speculative_config')
|
||||
if speculative_mode and \
|
||||
hasattr(vllm_config.speculative_config, 'num_speculative_tokens') and \
|
||||
vllm_config.speculative_config.num_speculative_tokens != 1:
|
||||
raise ValueError(f'run_mp_engine: only num_speculative_tokens == 1 is supported, but got {vllm_config.speculative_config.num_speculative_tokens}')
|
||||
|
||||
default_model_infos = "default"
|
||||
if speculative_mode:
|
||||
if hasattr(vllm_config.speculative_config, 'method'):
|
||||
default_model_infos = vllm_config.speculative_config.method
|
||||
|
||||
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
|
||||
vllm_vacc_config_manager().update_model_infos(default_model_infos)
|
||||
|
||||
try:
|
||||
# Ensure we can serialize transformer config before spawning
|
||||
maybe_register_config_serialize_by_value()
|
||||
from vllm.engine.multiprocessing.engine import MQLLMEngine,signal_handler
|
||||
engine = MQLLMEngine.from_vllm_config(
|
||||
vllm_config=vllm_config,
|
||||
usage_context=usage_context,
|
||||
disable_log_stats=disable_log_stats,
|
||||
disable_log_requests=disable_log_requests,
|
||||
ipc_path=ipc_path)
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
engine.start()
|
||||
|
||||
except BaseException as e:
|
||||
logger.exception(e)
|
||||
engine_alive.value = False
|
||||
raise e
|
||||
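# The num_speculative_tokens == 1 guard above is duplicated in
# vllm_vacc/vllm/engine/llm_engine.py; a shared helper could look like this
# (a sketch, not part of the original file).
def _check_single_speculative_token(vllm_config: VllmConfig) -> None:
    spec = getattr(vllm_config, 'speculative_config', None)
    num = getattr(spec, 'num_speculative_tokens', None)
    if num is not None and num != 1:
        raise ValueError(
            f'only num_speculative_tokens == 1 is supported, but got {num}')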
0
vllm_vacc/vllm/entrypoints/__init__.py
Normal file
0
vllm_vacc/vllm/entrypoints/__init__.py
Normal file
BIN
vllm_vacc/vllm/entrypoints/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/entrypoints/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/entrypoints/__pycache__/llm.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/entrypoints/__pycache__/llm.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/entrypoints/__pycache__/renderer.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/entrypoints/__pycache__/renderer.cpython-312.pyc
Normal file
Binary file not shown.
102
vllm_vacc/vllm/entrypoints/llm.py
Normal file
102
vllm_vacc/vllm/entrypoints/llm.py
Normal file
@@ -0,0 +1,102 @@
|
||||
|
||||
|
||||
import itertools
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
|
||||
cast, overload)
|
||||
|
||||
import cloudpickle
|
||||
import torch.nn as nn
|
||||
from pydantic import ValidationError
|
||||
from tqdm.auto import tqdm
|
||||
from typing_extensions import TypeVar, deprecated
|
||||
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import (RequestOutputKind, SamplingParams)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_R = TypeVar("_R", default=Any)
|
||||
|
||||
class LLM:
|
||||
|
||||
DEPRECATE_LEGACY: ClassVar[bool] = True
|
||||
def _validate_and_add_requests(
|
||||
self,
|
||||
prompts: Union[PromptType, Sequence[PromptType]],
|
||||
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
|
||||
Sequence[PoolingParams]],
|
||||
*,
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
|
||||
priority: Optional[list[int]] = None,
|
||||
) -> None:
|
||||
|
||||
if isinstance(prompts, (str, dict)):
|
||||
# Convert a single prompt to a list.
|
||||
prompts = [prompts]
|
||||
|
||||
num_requests = len(prompts)
|
||||
if isinstance(params, Sequence) and len(params) != num_requests:
|
||||
raise ValueError("The lengths of prompts and params "
|
||||
"must be the same.")
|
||||
if isinstance(lora_request,
|
||||
Sequence) and len(lora_request) != num_requests:
|
||||
raise ValueError("The lengths of prompts and lora_request "
|
||||
"must be the same.")
|
||||
|
||||
for sp in params if isinstance(params, Sequence) else (params, ):
|
||||
if isinstance(sp, SamplingParams):
|
||||
# We only care about the final output
|
||||
sp.output_kind = RequestOutputKind.FINAL_ONLY
|
||||
|
||||
# Add requests to the engine.
|
||||
it = prompts
|
||||
if use_tqdm:
|
||||
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
|
||||
it = tqdm_func(it, desc="Adding requests")
|
||||
|
||||
if (hasattr(current_platform, 'supports_v1') and current_platform.supports_v1(current_platform)):
|
||||
batch_items = []
|
||||
model_config = self.llm_engine.model_config
|
||||
for i, prompt in enumerate(it):
|
||||
request_id = str(next(self.request_counter))
|
||||
# print("requset_id===========", request_id)
|
||||
param = params[i] if isinstance(params, Sequence) else params
|
||||
tokenization_kwargs: dict[str, Any] = {}
|
||||
_validate_truncation_size(model_config.max_model_len,
|
||||
param.truncate_prompt_tokens,
|
||||
tokenization_kwargs)
|
||||
|
||||
batch_items.append((
|
||||
request_id,
|
||||
prompt,
|
||||
params[i] if isinstance(params, Sequence) else params,
|
||||
None,  # arrival_time; pass None when unused
|
||||
(lora_request[i] if isinstance(lora_request, Sequence)
|
||||
else lora_request),
|
||||
tokenization_kwargs,
|
||||
None,  # trace_headers (None when no APM/tracing is used)
|
||||
(priority[i] if priority else 0),
|
||||
))
|
||||
# hand the whole batch to EngineCore in one call (via ADD_BULK)
|
||||
self.llm_engine.add_requests(batch_items)
|
||||
else:
|
||||
for i, prompt in enumerate(it):
    param = params[i] if isinstance(params, Sequence) else params
    # tokenization_kwargs is only built in the V1 branch above; recompute
    # it here so this path does not reference an undefined local.
    tokenization_kwargs: dict[str, Any] = {}
    _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                              param.truncate_prompt_tokens,
                              tokenization_kwargs)
    self._add_request(
        prompt,
        param,
        tokenization_kwargs=tokenization_kwargs,
        lora_request=lora_request[i] if isinstance(
            lora_request, Sequence) else lora_request,
        priority=priority[i] if priority else 0,
    )
|
||||
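# Illustrative sketch (an assumption, not original code): the per-request
# tuple that the V1 path above hands to EngineCore via add_requests().
def _example_batch_item(request_id, prompt, sampling_params):
    return (
        request_id,       # str, drawn from self.request_counter
        prompt,           # PromptType
        sampling_params,  # SamplingParams with output_kind=FINAL_ONLY
        None,             # arrival_time
        None,             # lora_request
        {},               # tokenization_kwargs
        None,             # trace_headers
        0,                # priority
    )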
0
vllm_vacc/vllm/entrypoints/openai/__init__.py
Normal file
0
vllm_vacc/vllm/entrypoints/openai/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
345
vllm_vacc/vllm/entrypoints/openai/serving_completion.py
Normal file
345
vllm_vacc/vllm/entrypoints/openai/serving_completion.py
Normal file
@@ -0,0 +1,345 @@
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import jinja2
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice,
|
||||
CompletionStreamResponse,
|
||||
ErrorResponse,
|
||||
RequestResponseMetadata,
|
||||
UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (
|
||||
EmbedsPrompt as ServingEngineEmbedsPrompt)
|
||||
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
|
||||
TextTokensPrompt,
|
||||
clamp_prompt_logprobs,
|
||||
is_text_tokens_prompt)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.utils import get_max_tokens
|
||||
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
|
||||
is_tokens_prompt)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import merge_async_iterators
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
|
||||
from vllm.entrypoints.openai.serving_completion import logger
|
||||
from vllm.utils import (is_list_of, make_async, merge_async_iterators,
|
||||
random_uuid)
|
||||
from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
|
||||
merge_async_iterators, random_uuid)
|
||||
from vllm_vacc.vllm.model_executor.models.vars import LLM_MAX_PREFILL_SEQ_LEN
|
||||
|
||||
|
||||
class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
enable_prompt_tokens_details: bool = False,
|
||||
enable_force_include_usage: bool = False,
|
||||
enable_strict_batch_barrier: bool = True,
|
||||
log_error_stack: bool = False,
|
||||
):
|
||||
|
||||
self.engine_client = engine_client
|
||||
self.model_config = model_config
|
||||
self.max_model_len = model_config.max_model_len
|
||||
|
||||
self.models = models
|
||||
|
||||
self.request_logger = request_logger
|
||||
self.return_tokens_as_token_ids = return_tokens_as_token_ids
|
||||
self.enable_force_include_usage = enable_force_include_usage
|
||||
|
||||
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
self._async_tokenizer_pool: dict[AnyTokenizer,
|
||||
AsyncMicrobatchTokenizer] = {}
|
||||
self.log_error_stack = log_error_stack
|
||||
|
||||
self.enable_prompt_tokens_details = enable_prompt_tokens_details
|
||||
self.default_sampling_params = (
|
||||
self.model_config.get_diff_sampling_param())
|
||||
if self.default_sampling_params:
|
||||
source = self.model_config.generation_config
|
||||
source = "model" if source == "auto" else source
|
||||
logger.info("Using default completion sampling params from %s: %s",
|
||||
source, self.default_sampling_params)
|
||||
self.enable_strict_batch_barrier = enable_strict_batch_barrier
|
||||
|
||||
|
||||
async def create_completion(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/completions/create
|
||||
for the API specification. This API mimics the OpenAI Completion API.
|
||||
|
||||
NOTE: Currently we do not support the following feature:
|
||||
- suffix (the language models we currently support do not support
|
||||
suffix)
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
# Return error for unsupported features.
|
||||
if request.suffix is not None:
|
||||
return self.create_error_response(
|
||||
"suffix is not currently supported")
|
||||
|
||||
if request.echo and request.prompt_embeds is not None:
|
||||
return self.create_error_response(
|
||||
"Echo is unsupported with prompt embeds.")
|
||||
|
||||
if (request.prompt_logprobs is not None
|
||||
and request.prompt_embeds is not None):
|
||||
return self.create_error_response(
|
||||
"prompt_logprobs is not compatible with prompt embeds.")
|
||||
|
||||
request_id = (
|
||||
f"cmpl-"
|
||||
f"{self._base_request_id(raw_request, request.request_id)}")
|
||||
created_time = int(time.time())
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self._get_renderer(tokenizer)
|
||||
|
||||
engine_prompts = await renderer.render_prompt_and_embeds(
|
||||
prompt_or_prompts=request.prompt,
|
||||
prompt_embeds=request.prompt_embeds,
|
||||
deepstack_input_embeds=request.deepstack_input_embeds if hasattr(request, 'deepstack_input_embeds') else None,
|
||||
config=self._build_render_config(request),
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
except TypeError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
except RuntimeError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
except jinja2.TemplateError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
total_num_prompts = len(engine_prompts)
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
sampling_params: Union[SamplingParams, BeamSearchParams]
|
||||
# Mypy does not infer that engine_prompt will have only one of
|
||||
# "prompt_token_ids" or "prompt_embeds" defined, and both of
|
||||
# these as Union[object, the expected type], where it infers
|
||||
# object if engine_prompt is a subclass of one of the
|
||||
# typeddicts that defines both keys. Worse, because of
|
||||
# https://github.com/python/mypy/issues/8586, mypy does not
|
||||
# infer the type of engine_prompt correctly because of the
|
||||
# enumerate. So we need an unnecessary cast here.
|
||||
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
|
||||
engine_prompt)
|
||||
if is_embeds_prompt(engine_prompt):
|
||||
input_length = len(engine_prompt["prompt_embeds"])
|
||||
elif is_tokens_prompt(engine_prompt):
|
||||
input_length = len(engine_prompt["prompt_token_ids"])
|
||||
if input_length > LLM_MAX_PREFILL_SEQ_LEN:
|
||||
raise ValueError(
|
||||
f"This model's maximum input seq length limit is "
|
||||
f"{LLM_MAX_PREFILL_SEQ_LEN} tokens. However, you requested "
|
||||
f"({input_length} in the input messages, "
|
||||
f"Please reduce the length of the input messages.")
|
||||
else:
|
||||
assert_never(engine_prompt)
|
||||
|
||||
if self.default_sampling_params is None:
|
||||
self.default_sampling_params = {}
|
||||
|
||||
max_tokens = get_max_tokens(
|
||||
max_model_len=self.max_model_len,
|
||||
request=request,
|
||||
input_length=input_length,
|
||||
default_sampling_params=self.default_sampling_params,
|
||||
)
|
||||
|
||||
if request.use_beam_search:
|
||||
sampling_params = request.to_beam_search_params(
|
||||
max_tokens, self.default_sampling_params)
|
||||
else:
|
||||
sampling_params = request.to_sampling_params(
|
||||
max_tokens,
|
||||
self.model_config.logits_processor_pattern,
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
# Inject strict batch barrier metadata so this batch is held
|
||||
# until all items are ready, then scheduled together.
|
||||
if (self.enable_strict_batch_barrier
|
||||
and total_num_prompts > 1
|
||||
and isinstance(sampling_params, SamplingParams)):
|
||||
if sampling_params.extra_args is None:
|
||||
sampling_params.extra_args = {}
|
||||
sampling_params.extra_args.setdefault("barrier_group_id",
|
||||
request_id)
|
||||
sampling_params.extra_args.setdefault("barrier_group_size",
|
||||
total_num_prompts)
|
||||
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
# Mypy inconsistently requires this second cast in different
|
||||
# environments. It shouldn't be necessary (redundant from above)
|
||||
# but pre-commit in CI fails without it.
|
||||
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
|
||||
engine_prompt)
|
||||
if isinstance(sampling_params, BeamSearchParams):
|
||||
generator = self.engine_client.beam_search(
|
||||
prompt=engine_prompt,
|
||||
request_id=request_id,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
else:
|
||||
generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
logger.error(e)
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
model_name = self.models.model_name(lora_request)
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
||||
# results. Noting that best_of is only supported in V0. In addition,
|
||||
# we do not stream the results when use beam search.
|
||||
stream = (request.stream
|
||||
and (request.best_of is None or request.n == request.best_of)
|
||||
and not request.use_beam_search)
|
||||
|
||||
# Streaming response
|
||||
if stream:
|
||||
return self.completion_stream_generator(
|
||||
request,
|
||||
engine_prompts,
|
||||
result_generator,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
num_prompts=num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
request_metadata=request_metadata,
|
||||
enable_force_include_usage=self.enable_force_include_usage,
|
||||
)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
for i, final_res in enumerate(final_res_batch):
|
||||
assert final_res is not None
|
||||
|
||||
# The output should contain the input text
|
||||
# We did not pass it into vLLM engine to avoid being redundant
|
||||
# with the inputs token IDs
|
||||
if final_res.prompt is None:
|
||||
engine_prompt = engine_prompts[i]
|
||||
final_res.prompt = None if is_embeds_prompt(
|
||||
engine_prompt) else engine_prompt.get("prompt")
|
||||
|
||||
final_res_batch_checked = cast(list[RequestOutput],
|
||||
final_res_batch)
|
||||
|
||||
response = self.request_output_to_completion_response(
|
||||
final_res_batch_checked,
|
||||
request,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
tokenizer,
|
||||
request_metadata,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
if request.stream:
|
||||
response_json = response.model_dump_json()
|
||||
|
||||
async def fake_stream_generator() -> AsyncGenerator[str, None]:
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return fake_stream_generator()
|
||||
|
||||
return response
|
||||
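# Hedged sketch (not original code): what the strict-batch-barrier injection
# in create_completion leaves in SamplingParams.extra_args for a 3-prompt
# request; the request id is a placeholder.
def _example_barrier_extra_args(request_id="cmpl-abc123"):
    extra_args = {}
    extra_args.setdefault("barrier_group_id", request_id)
    extra_args.setdefault("barrier_group_size", 3)
    return extra_args  # attached to every item "cmpl-abc123-0" .. "cmpl-abc123-2"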
191
vllm_vacc/vllm/entrypoints/openai/serving_engine.py
Normal file
191
vllm_vacc/vllm/entrypoints/openai/serving_engine.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import json
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from http import HTTPStatus
|
||||
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
|
||||
Optional, Sequence, Tuple, TypedDict, Union)
|
||||
|
||||
from fastapi import Request
|
||||
from pydantic import Field
|
||||
from starlette.datastructures import Headers
|
||||
from typing_extensions import Annotated
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
||||
ChatTemplateContentFormatOption,
|
||||
ConversationMessage,
|
||||
apply_hf_chat_template,
|
||||
apply_mistral_chat_template,
|
||||
parse_chat_messages_futures,
|
||||
resolve_chat_template_content_format)
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
DetokenizeRequest,
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
ErrorResponse, RerankRequest,
|
||||
ScoreRequest,
|
||||
TokenizeChatRequest,
|
||||
TokenizeCompletionRequest)
|
||||
# yapf: enable
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.entrypoints.openai.serving_engine import AnyRequest, TextTokensPrompt
|
||||
# from vllm.model_executor.sampling_metadata import _SAMPLING_EPS
|
||||
from vllm.v1.sample.sampler import _SAMPLING_EPS
|
||||
import os
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
from vllm_vacc.vllm.model_executor.models.vars import LLM_MAX_PREFILL_SEQ_LEN
|
||||
from vllm_vacc.vllm.model_executor.models.vars import CUT_PREFILL_SEQ_LEN
|
||||
|
||||
class EmbedsPrompt(TypedDict):
|
||||
prompt_embeds: torch.Tensor
|
||||
deepstack_input_embeds: Optional[dict]
|
||||
|
||||
class OpenAIServing:
|
||||
def _validate_input(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
input_ids: List[int],
|
||||
input_text: str,
|
||||
) -> TextTokensPrompt:
|
||||
# parameters set by the client; if unset, they are read again from generation_config.json
|
||||
if CUT_PREFILL_SEQ_LEN > 0 and CUT_PREFILL_SEQ_LEN < len(input_ids):
|
||||
cut_before = CUT_PREFILL_SEQ_LEN // 2
|
||||
cut_after = CUT_PREFILL_SEQ_LEN - cut_before
|
||||
input_ids = input_ids[:cut_before] + input_ids[(-1)*cut_after:]
|
||||
token_num = len(input_ids)
|
||||
|
||||
if not self.model_config.pooler_config:
|
||||
if (request.repetition_penalty is not None and abs(request.repetition_penalty - 1.0) >= _SAMPLING_EPS):
|
||||
raise ValueError(
|
||||
f"unsupport penalty for sampler"
|
||||
f"request.repetition_penalty: {request.repetition_penalty}; "
|
||||
f"Please remove penalty parameter in client and try again."
|
||||
)
|
||||
if request.min_p is not None and request.min_p > _SAMPLING_EPS:
|
||||
raise ValueError(f"unsupport min_p {request.min_p} for sampler")
|
||||
if request.prompt_logprobs is not None:
|
||||
raise ValueError(f"unsupport prompt_logprobs {request.prompt_logprobs} for sampler")
|
||||
|
||||
|
||||
# model_type = self.model_config.hf_config.model_type
|
||||
# if model_type == "deepseek_v3":
|
||||
if token_num > LLM_MAX_PREFILL_SEQ_LEN:
|
||||
raise ValueError(
|
||||
f"This model's maximum input seq length limit is "
|
||||
f"{LLM_MAX_PREFILL_SEQ_LEN} tokens. However, you requested "
|
||||
f"({token_num} in the input messages, "
|
||||
f"Please reduce the length of the input messages.")
|
||||
|
||||
# Note: EmbeddingRequest and ScoreRequest doesn't have max_tokens
|
||||
if isinstance(request,
|
||||
(EmbeddingChatRequest, EmbeddingCompletionRequest,
|
||||
ScoreRequest, RerankRequest)):
|
||||
|
||||
operation = "score" if isinstance(request, ScoreRequest) \
|
||||
else "embedding generation"
|
||||
if token_num > self.max_model_len:
|
||||
raise ValueError(
|
||||
f"This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens. However, you requested "
|
||||
f"{token_num} tokens in the input for {operation}. "
|
||||
f"Please reduce the length of the input.")
|
||||
return TextTokensPrompt(prompt=input_text,
|
||||
prompt_token_ids=input_ids)
|
||||
|
||||
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
|
||||
# and does not require model context length validation
|
||||
if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
|
||||
DetokenizeRequest)):
|
||||
return TextTokensPrompt(prompt=input_text,
|
||||
prompt_token_ids=input_ids)
|
||||
|
||||
# chat completion endpoint supports max_completion_tokens
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
# TODO(#9845): remove max_tokens when field dropped from OpenAI API
|
||||
max_tokens = request.max_completion_tokens or request.max_tokens
|
||||
else:
|
||||
max_tokens = request.max_tokens
|
||||
if max_tokens is None:
|
||||
if token_num >= self.max_model_len:
|
||||
raise ValueError(
|
||||
f"This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens. However, you requested "
|
||||
f"{token_num} tokens in the messages, "
|
||||
f"Please reduce the length of the messages.")
|
||||
elif token_num + max_tokens > self.max_model_len:
|
||||
raise ValueError(
|
||||
f"This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens. However, you requested "
|
||||
f"{max_tokens + token_num} tokens "
|
||||
f"({token_num} in the messages, "
|
||||
f"{max_tokens} in the completion). "
|
||||
f"Please reduce the length of the messages or completion.")
|
||||
|
||||
|
||||
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
|
||||
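# Illustrative sketch (an assumption, not original code): the middle
# truncation applied at the top of _validate_input keeps the first half and
# the last half of the token ids when CUT_PREFILL_SEQ_LEN is set.
@staticmethod
def _example_middle_truncation(input_ids, cut_len):
    if 0 < cut_len < len(input_ids):
        cut_before = cut_len // 2
        cut_after = cut_len - cut_before
        return input_ids[:cut_before] + input_ids[-cut_after:]
    return input_ids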
|
||||
def _log_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
inputs,
|
||||
params: Optional[Union[SamplingParams, PoolingParams,
|
||||
BeamSearchParams]],
|
||||
lora_request: Optional[LoRARequest],
|
||||
) -> None:
|
||||
# the request_logger None-check was moved down, to just before the logger is used
|
||||
# if self.request_logger is None:
|
||||
# return
|
||||
# if self.model_config.pooler_config is not None, this is an embedding task, not a generation task
|
||||
if self.model_config.pooler_config:
|
||||
return
|
||||
prompt, prompt_token_ids, prompt_embeds = None, None, None
|
||||
if isinstance(inputs, str):
|
||||
prompt = inputs
|
||||
elif isinstance(inputs, list):
|
||||
prompt_token_ids = inputs
|
||||
else:
|
||||
prompt = getattr(inputs, 'prompt', None)
|
||||
prompt_token_ids = getattr(inputs, 'prompt_token_ids', None)
|
||||
|
||||
# penalty settings read from generation_config: if present, warn and reset them
|
||||
if (getattr(params, "repetition_penalty", None) is not None and abs(params.repetition_penalty - 1.0) >= _SAMPLING_EPS):
|
||||
logger.warning(
|
||||
"\033[93mWARNING \033[0m"
|
||||
": Unsupport penalty for sampler"
|
||||
f"params.repetition_penalty: {params.repetition_penalty} and "
|
||||
"Please set attrs: extra_body = {\'repetition_penalty\': 1.0}\n"
|
||||
"Now set: repetition_penalty: 1.0"
|
||||
)
|
||||
# params.presence_penalty = 0
|
||||
# params.frequency_penalty = 0
|
||||
params.repetition_penalty = 1
|
||||
|
||||
if hasattr(params, "min_p") and params.min_p is not None and params.min_p > _SAMPLING_EPS:
|
||||
logger.warning(f"\033[93mWARNING \033[0m : unsupport min_p {params.min_p} for sampler")
|
||||
params.min_p = 0
|
||||
if self.request_logger is None:
|
||||
return
|
||||
self.request_logger.log_inputs(
|
||||
request_id,
|
||||
prompt,
|
||||
prompt_token_ids,
|
||||
prompt_embeds,
|
||||
params=params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
127
vllm_vacc/vllm/entrypoints/renderer.py
Normal file
127
vllm_vacc/vllm/entrypoints/renderer.py
Normal file
@@ -0,0 +1,127 @@
|
||||
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Optional, Union
|
||||
|
||||
import pybase64
|
||||
import torch
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.inputs.parse import parse_and_batch_prompt
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import AsyncMicrobatchTokenizer
|
||||
|
||||
|
||||
|
||||
class BaseRenderer(ABC):
|
||||
"""
|
||||
Base class for unified input processing and rendering.
|
||||
|
||||
The Renderer serves as a unified input processor that consolidates
|
||||
tokenization, chat template formatting, and multimodal input handling
|
||||
into a single component.
|
||||
It converts high-level API requests (OpenAI-style JSON) into token IDs and
|
||||
multimodal features ready for engine consumption.
|
||||
|
||||
Key responsibilities:
|
||||
- Convert text prompts to token sequences with proper special tokens
|
||||
- Apply chat templates and format conversations
|
||||
- Handle multimodal inputs (images, audio, etc.) when applicable
|
||||
- Manage prompt truncation and length validation
|
||||
- Provide clean separation between API layer and engine core
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def load_prompt_embeds(
|
||||
cls,
|
||||
prompt_embeds: Union[bytes, list[bytes]],
|
||||
deepstack_input_embeds: Optional[dict[str, Union[bytes, str]]] = None,
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=0)]] = None,
|
||||
cache_salt: Optional[str] = None,
|
||||
) -> list[EngineEmbedsPrompt]:
|
||||
"""Load and validate base64-encoded embeddings into prompt objects."""
|
||||
|
||||
def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt:
|
||||
tensor = torch.load(
|
||||
io.BytesIO(pybase64.b64decode(embed, validate=True)),
|
||||
weights_only=True,
|
||||
map_location=torch.device("cpu"),
|
||||
)
|
||||
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
|
||||
torch.float32,
|
||||
torch.bfloat16,
|
||||
torch.float16,
|
||||
)
|
||||
tensor = tensor.to_dense()
|
||||
if tensor.dim() > 2:
|
||||
tensor = tensor.squeeze(0)
|
||||
assert tensor.dim() == 2
|
||||
if truncate_prompt_tokens is not None:
|
||||
tensor = tensor[-truncate_prompt_tokens:]
|
||||
embeds_prompt = EngineEmbedsPrompt(prompt_embeds=tensor)
|
||||
if cache_salt is not None:
|
||||
embeds_prompt["cache_salt"] = cache_salt
|
||||
|
||||
if deepstack_input_embeds is not None:
|
||||
all_tensor = []
|
||||
from vllm.sequence import IntermediateTensors
|
||||
tensor_dict = torch.load(
|
||||
io.BytesIO(pybase64.b64decode(deepstack_input_embeds, validate=True))
|
||||
)
|
||||
for k in tensor_dict:
|
||||
all_tensor.append(tensor_dict[k].unsqueeze(0))
|
||||
|
||||
all_tensor = torch.concatenate(all_tensor, 0)
|
||||
embeds_prompt["deepstack_input_embeds"] = all_tensor #IntermediateTensors(tensors=tensor_dict)
|
||||
|
||||
return embeds_prompt
|
||||
|
||||
if isinstance(prompt_embeds, list):
|
||||
return [_load_and_validate_embed(embed) for embed in prompt_embeds]
|
||||
|
||||
return [_load_and_validate_embed(prompt_embeds)]
|
||||
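# Hedged sketch (an assumption, not original code): the client-side inverse of
# load_prompt_embeds above - serializing a [seq_len, hidden] tensor into the
# base64 payload expected in `prompt_embeds`.
@staticmethod
def _example_encode_prompt_embeds(tensor: torch.Tensor) -> bytes:
    buf = io.BytesIO()
    torch.save(tensor, buf)
    return pybase64.b64encode(buf.getvalue())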
|
||||
|
||||
|
||||
class CompletionRenderer(BaseRenderer):
|
||||
|
||||
async def render_prompt_and_embeds(
|
||||
self,
|
||||
*,
|
||||
prompt_or_prompts: Optional[Union[str, list[str], list[int],
|
||||
list[list[int]]]] = None,
|
||||
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
|
||||
deepstack_input_embeds: Optional[Union[bytes, list[bytes]]] = None,
|
||||
config: "RenderConfig",
|
||||
) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
|
||||
"""
|
||||
Render text/token prompts and/or precomputed embedding prompts. At
|
||||
least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
|
||||
"""
|
||||
truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
|
||||
config.truncate_prompt_tokens, config.max_length)
|
||||
if truncate_prompt_tokens == 0:
|
||||
return []
|
||||
|
||||
rendered: list[Union[EngineTokensPrompt, EngineEmbedsPrompt]] = []
|
||||
|
||||
if prompt_embeds is not None:
|
||||
rendered.extend(
|
||||
self.load_prompt_embeds(prompt_embeds, deepstack_input_embeds, truncate_prompt_tokens,
|
||||
config.cache_salt))
|
||||
if prompt_or_prompts is None or prompt_or_prompts == "":
|
||||
return rendered
|
||||
|
||||
token_prompts = await self.render_prompt(
|
||||
prompt_or_prompts=prompt_or_prompts,
|
||||
config=config,
|
||||
)
|
||||
rendered.extend(token_prompts)
|
||||
|
||||
return rendered
|
||||
0
vllm_vacc/vllm/executor/__init__.py
Normal file
0
vllm_vacc/vllm/executor/__init__.py
Normal file
BIN
vllm_vacc/vllm/executor/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
20
vllm_vacc/vllm/executor/executor_base.py
Normal file
@@ -0,0 +1,20 @@
import asyncio
from typing import List

from vllm.v1.outputs import PoolerOutput, SamplerOutput
from vllm.sequence import ExecuteModelRequest

# class DistributedExecutorBase():
#     """Abstract superclass of distributed executor implementations."""

async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
    if self.parallel_worker_tasks is None:
        # Start model execution loop running in the parallel workers
        self.parallel_worker_tasks = asyncio.create_task(
            self._start_worker_execution_loop())
        await asyncio.sleep(0)
    # Only the driver worker returns the sampling results.
    await asyncio.sleep(0)
    return await self._driver_execute_model_async(execute_model_req)
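`execute_model_async` above is written as a detached method body: it relies on the surrounding executor providing `parallel_worker_tasks`, `_start_worker_execution_loop`, and `_driver_execute_model_async`. The `await asyncio.sleep(0)` calls yield to the event loop so the freshly created worker-loop task can actually start before the driver result is awaited. A self-contained toy sketch of that scheduling pattern (the class and its methods are stand-ins, not the real vLLM executors):

import asyncio


class ToyExecutor:

    def __init__(self):
        self.parallel_worker_tasks = None

    async def _start_worker_execution_loop(self):
        # Stand-in for the parallel workers' busy loop.
        print("worker loop started")
        await asyncio.sleep(0.01)

    async def _driver_execute_model_async(self, req):
        # Stand-in for the driver worker producing sampler outputs.
        await asyncio.sleep(0.01)
        return [f"sampler output for {req}"]

    async def execute_model_async(self, req):
        if self.parallel_worker_tasks is None:
            self.parallel_worker_tasks = asyncio.create_task(
                self._start_worker_execution_loop())
            await asyncio.sleep(0)  # yield so the new task gets scheduled
        await asyncio.sleep(0)
        return await self._driver_execute_model_async(req)


print(asyncio.run(ToyExecutor().execute_model_async("req-0")))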
0
vllm_vacc/vllm/inputs/__init__.py
Normal file
BIN
vllm_vacc/vllm/inputs/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/inputs/__pycache__/data.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/inputs/__pycache__/preprocess.cpython-312.pyc
Normal file
Binary file not shown.
55
vllm_vacc/vllm/inputs/data.py
Normal file
@@ -0,0 +1,55 @@


from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast

import torch
from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar

if TYPE_CHECKING:
    from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalInputs,
                                        MultiModalUUIDDict)
    from vllm.sequence import IntermediateTensors


class EmbedsPrompt(TypedDict):
    """Schema for a prompt provided via token embeddings."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

    deepstack_input_embeds: Optional["IntermediateTensors"]
    """Optional per-layer DeepStack embeddings accompanying the prompt."""

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """


class EmbedsInputs(TypedDict):
    """Represents embeddings-based inputs."""

    type: Literal["embeds"]
    """The type of inputs."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

    deepstack_input_embeds: torch.Tensor

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """


def embeds_inputs(
    prompt_embeds: torch.Tensor,
    deepstack_input_embeds: torch.Tensor,
    cache_salt: Optional[str] = None,
) -> EmbedsInputs:
    """Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional
    values."""
    inputs = EmbedsInputs(type="embeds",
                          prompt_embeds=prompt_embeds,
                          deepstack_input_embeds=deepstack_input_embeds)

    if cache_salt is not None:
        inputs["cache_salt"] = cache_salt

    return inputs
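A short usage sketch for `embeds_inputs`; the shapes are arbitrary, and the `vllm.inputs.data` import path assumes this file is installed as `vllm/inputs/data.py` (as the sibling `preprocess.py` does via `from .data import ...`).

import torch

from vllm.inputs.data import embeds_inputs

prompt_embeds = torch.randn(16, 4096)   # (seq_len, hidden_size)
deepstack = torch.randn(3, 16, 4096)    # per-layer tensors stacked on dim 0

inputs = embeds_inputs(prompt_embeds=prompt_embeds,
                       deepstack_input_embeds=deepstack,
                       cache_salt="request-42")
assert inputs["type"] == "embeds"
assert inputs["prompt_embeds"].shape == (16, 4096)
assert "cache_salt" in inputs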
54
vllm_vacc/vllm/inputs/preprocess.py
Normal file
@@ -0,0 +1,54 @@


from collections.abc import Mapping
from typing import Any, Optional, Union, cast

from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                                    MultiModalInputs, MultiModalUUIDDict)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer

from .data import EmbedsInputs, EmbedsPrompt, embeds_inputs

logger = init_logger(__name__)


class InputPreprocessor:

    def _process_embeds(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
        if not self.model_config.enable_prompt_embeds:
            raise ValueError("You must set `--enable-prompt-embeds` to input "
                             "`prompt_embeds`.")

        prompt_embeds = parsed_content["prompt_embeds"]
        deepstack_input_embeds = None
        if "deepstack_input_embeds" in parsed_content:
            deepstack_input_embeds = parsed_content["deepstack_input_embeds"]

        # prompt_embeds must be (seq_len, hidden_size), but if the user
        # passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
        # we can unambiguously process the intent by squeezing the batch
        # dimension.
        if prompt_embeds.ndim == 3:
            prompt_embeds = prompt_embeds.squeeze(dim=0)

        if prompt_embeds.ndim != 2:
            raise ValueError(
                "prompt_embeds must be of shape (seq_len, hidden_size).")

        # Tensors must be on CPU for serialization between processes
        # in the MsgpackEncoder. Casting to CPU here ensures that there is no
        # hidden device transfer in the critical path of generation.
        prompt_embeds = prompt_embeds.cpu()

        return embeds_inputs(prompt_embeds=prompt_embeds,
                             deepstack_input_embeds=deepstack_input_embeds,
                             cache_salt=parsed_content.get("cache_salt"))
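To illustrate the shape handling in `_process_embeds` in isolation, a hedged sketch follows. It instantiates the stripped-down `InputPreprocessor` defined above and substitutes a `SimpleNamespace` for the real `ModelConfig`, so `enable_prompt_embeds` can be toggled without a full engine; everything else follows the code above.

from types import SimpleNamespace

import torch

pre = InputPreprocessor()
pre.model_config = SimpleNamespace(enable_prompt_embeds=True)  # stand-in config

# A (1, seq_len, hidden_size) batch of size 1 is unambiguous and gets squeezed.
out = pre._process_embeds({"prompt_embeds": torch.randn(1, 16, 4096)})
assert out["type"] == "embeds"
assert out["prompt_embeds"].shape == (16, 4096)
assert out["prompt_embeds"].device.type == "cpu"

# Anything that is not 2-D after the squeeze is rejected.
try:
    pre._process_embeds({"prompt_embeds": torch.randn(2, 16, 4096, 1)})
except ValueError as e:
    print(e)  # prompt_embeds must be of shape (seq_len, hidden_size).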
0
vllm_vacc/vllm/model_executor/__init__.py
Normal file
38
vllm_vacc/vllm/model_executor/custom_op.py
Normal file
@@ -0,0 +1,38 @@

import torch.nn as nn

from vllm.logger import init_logger

logger = init_logger(__name__)


class CustomOp(nn.Module):

    def forward_vacc(self, *args, **kwargs):
        raise NotImplementedError

    def dispatch_forward(self):
        # NOTE(woosuk): Here we assume that vLLM was built for only one
        # specific backend. Currently, we do not support dynamic dispatching.

        enabled = self.enabled()
        # NOTE: `self.__class__.name` assumes subclasses define a `name`
        # class attribute (as registered CustomOps do in upstream vLLM).
        logger.debug("custom op %s %s", self.__class__.name,
                     "enabled" if enabled else "disabled")

        if not enabled:
            return self.forward_native

        # The VACC build always dispatches to `forward`; the per-platform
        # branches below are unreachable dead code retained from upstream
        # (`current_platform` is not imported in this file).
        return self.forward

        if current_platform.is_rocm():
            return self.forward_hip
        elif current_platform.is_cpu():
            return self.forward_cpu
        elif current_platform.is_hpu():
            return self.forward_hpu
        elif current_platform.is_tpu():
            return self.forward_tpu
        elif current_platform.is_xpu():
            return self.forward_xpu
        elif current_platform.is_vacc():
            return self.forward
        else:
            return self.forward_cuda
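A hypothetical subclass to show how `dispatch_forward` resolves. The `name` class attribute and the `enabled()` helper are assumptions (upstream vLLM provides them through `CustomOp.register(...)` and compilation config); the SiLU-and-mul math is only a stand-in body so the sketch runs.

import torch
import torch.nn.functional as F


class SiluAndMulVacc(CustomOp):
    name = "silu_and_mul_vacc"  # assumed; dispatch_forward logs this

    def enabled(self) -> bool:
        # Upstream vLLM derives this from config; hard-coded for the sketch.
        return True

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A real VACC build would call a device kernel here; reuse the
        # native formula as a placeholder.
        return self.forward_native(x)


op = SiluAndMulVacc()
fwd = op.dispatch_forward()   # resolves to op.forward because enabled() is True
y = fwd(torch.randn(4, 8))    # -> shape (4, 4)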
0
vllm_vacc/vllm/model_executor/layers/__init__.py
Normal file
Some files were not shown because too many files have changed in this diff.