2026-04-02 04:53:13 +00:00
parent 80932c96e5
commit 24df76db9d
1987 changed files with 447445 additions and 0 deletions

vllm_vacc/__init__.py (Normal file, 66 lines)

@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
def generate_bootinfo():
from pathlib import Path
import datetime
import os
workspace_path = Path.cwd()
bootinfo_time = str(datetime.datetime.now())
bootinfo_tag_vccl = "[vccl]"
bootinfo_seq = "@@@"
bootinfo_config = f'{workspace_path}/.bootinfos'
bootinfo_inited = os.path.exists(bootinfo_config)
bootinfo_start_up_times = 0
if bootinfo_inited:
try:
with open(bootinfo_config) as w:
current_bootinfos = w.readline()
# print("current_bootinfos:",current_bootinfos)
bootinfo_start_up_times = int(current_bootinfos.split('cycle')[1].strip())
except Exception as e:
print("[WARN] read bootinfo fail, caused by ", e)
# limit max run times
if bootinfo_start_up_times > 10000000:
bootinfo_start_up_times = 0
current_bootinfos = f'{bootinfo_tag_vccl}{bootinfo_seq}{bootinfo_time} cycle {bootinfo_start_up_times + 1}'
try:
with open(bootinfo_config, 'w') as w:
w.write(current_bootinfos)
except Exception as e:
print("[WARN] write bootinfo fail, caused by ", e)
def exec_patching():
import sys
from vllm_vacc.vllm_patch import VllmPatchManager, patch_vllm, patch_vllm_v1, patch_torch, regist_mock_module
vpm = VllmPatchManager
if vpm.patched:
return
regist_mock_module()
patch_torch()
patch_vllm(vpm)
patch_vllm_v1(vpm)
vpm.apply_patches()
generate_bootinfo()
if "vllm.executor.executor_base" in sys.modules:
del sys.modules["vllm.executor.executor_base"]
def register():
return "vllm_vacc.platform.VaccPlatform"
def register_model():
exec_patching()
# from .models import register_model
# register_model()
return
# exec_patching()
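For reference, a minimal sketch of how the counter written by generate_bootinfo() could be read back. The .bootinfos file name and the "cycle" delimiter come from the code above; the helper name and the standalone-script form are illustrative assumptions, not part of the plugin.

from pathlib import Path

def read_boot_cycle(workspace: Path) -> int:
    # The file holds a single line such as "[vccl]@@@<timestamp> cycle <n>";
    # the counter is whatever follows the literal "cycle" marker.
    bootinfo_path = workspace / ".bootinfos"
    try:
        line = bootinfo_path.read_text().splitlines()[0]
        return int(line.split("cycle")[1].strip())
    except (OSError, IndexError, ValueError):
        return 0

if __name__ == "__main__":
    print(read_boot_cycle(Path.cwd()))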

(binary files not shown)

vllm_vacc/patch_util.py (Normal file, 299 lines)

@@ -0,0 +1,299 @@
import importlib
import sys
import pkgutil
import types
from typing import List
PATCH_ROOT = "vllm_vacc."
def get_func_name(func):
if isinstance(func, str):
return func
return ".".join((func.__module__, func.__qualname__))
def dummy_function_wrapper(func_name):
def dummy_function(*args, **kwargs):
raise RuntimeError(f"function {func_name} no exist")
return dummy_function
def dummy_jit(fn):
def wrapper(*args, **kwargs):
return fn(*args, **kwargs)
return wrapper
class Patch:
def __init__(self, orig_func_name, new_func, create_dummy):
split_name = orig_func_name.rsplit(".", 1)
if len(split_name) == 1:
self.orig_module_name, self.orig_func_name = orig_func_name, None
else:
self.orig_module_name, self.orig_func_name = split_name
self.orig_module = None
self.orig_func = None
self.patch_func = None
self.wrappers = []
if new_func is None:
new_func = dummy_function_wrapper(orig_func_name)
self.set_patch_func(new_func)
self.is_applied = False
self.create_dummy = create_dummy
@property
def orig_func_id(self):
return id(self.orig_func)
@property
def patch_func_id(self):
return id(self.patch_func)
def set_patch_func(self, new_func, force_patch=False):
if hasattr(new_func, "__name__") and new_func.__name__.endswith(
("wrapper", "decorator")
):
self.wrappers.append(new_func)
else:
if self.patch_func and not force_patch:
raise RuntimeError(
f"The patch of '{self.orig_func_name}' ('{self.patch_func}') exist!"
)
self.patch_func = new_func
self.is_applied = False
def apply_patch(self):
if self.is_applied:
return
self.orig_module, self.orig_func = Patch.parse_path(
self.orig_module_name, self.orig_func_name, self.create_dummy
)
if self.patch_func is None:
self.patch_func = self.orig_func
for wrapper in self.wrappers:
self.patch_func = wrapper(self.patch_func)
if self.orig_func_name is not None:
setattr(self.orig_module, self.orig_func_name, self.patch_func)
for key, value in sys.modules.copy().items():
# Iterate over every loaded module and setattr the patched function; unrelated modules may not match and could cause issues. TODO: consider restricting this to vllm-related modules only.
try:
if (
self.orig_func_name is not None
and hasattr(value, self.orig_func_name)
and id(getattr(value, self.orig_func_name)) == self.orig_func_id
):
setattr(value, self.orig_func_name, self.patch_func)
except Exception:
continue
self.is_applied = True
@staticmethod
def parse_function(function_path: str, create_dummy):
split_name = function_path.rsplit(".", 1)
if len(split_name) == 1:
orig_module_name, orig_func_name = function_path, None
else:
orig_module_name, orig_func_name = split_name
return Patch.parse_path(orig_module_name, orig_func_name, create_dummy)[1]
@staticmethod
def parse_path(module_path, function_name, create_dummy):
from importlib.machinery import ModuleSpec
modules = module_path.split(".")
for i in range(1, len(modules) + 1):
parent = ".".join(modules[: i - 1])
path = ".".join(modules[:i])
try:
importlib.import_module(path)
except ModuleNotFoundError as e:
if not parent or not hasattr(
importlib.import_module(parent), modules[i - 1]
):
if not create_dummy:
raise ModuleNotFoundError(e) from e
sys.modules[path] = types.ModuleType(path)
sys.modules[path].__file__ = "patch_tools.dummy_module.py"
sys.modules[path].__spec__ = ModuleSpec(path, None)
if parent:
setattr(
importlib.import_module(parent),
modules[i - 1],
sys.modules[path],
)
else:
module = getattr(importlib.import_module(parent), modules[i - 1])
if hasattr(module, function_name):
return module, getattr(module, function_name)
elif create_dummy:
return module, dummy_function_wrapper(function_name)
else:
raise RuntimeError(
f"Function '{function_name}' of '{module}' does not exist."
) from e
if function_name is not None and not hasattr(
sys.modules[module_path], function_name
):
setattr(sys.modules[module_path], function_name, None)
return sys.modules[module_path], (
getattr(sys.modules[module_path], function_name)
if function_name is not None
else None
)
class PatchManager:
patches_info: dict = {}
patched: bool = False
@classmethod
def get_patch_info(cls):
return cls.patches_info
@classmethod
def register_patch(
cls,
orig_func_name,
new_func=None,
force_patch=False,
create_dummy=False,
allow_create=False,
):
"""
if new_func is written via @wraps, its name must be ended with `wrapper` or `decorator`,
also if it ends with `wrapper` or `decorator`, it must be written via `@wraps`
"""
if not cls._path_valid(orig_func_name):
if not allow_create:
raise ValueError(
f"Module/function path '{orig_func_name}' does not exist, and allow_create=False."
)
# if not create_dummy and not cls._path_valid(orig_func_name):
# print(f"WARNING: path '{orig_func_name}' is not valid, skipped.")
# return
patch_info = cls.get_patch_info()
if orig_func_name not in patch_info:
patch_info[orig_func_name] = Patch(orig_func_name, new_func, create_dummy)
else:
patch_info[orig_func_name].set_patch_func(new_func, force_patch)
@classmethod
def _is_module(cls, module_path):
try:
importlib.import_module(module_path)
return True
except ModuleNotFoundError:
return False
@classmethod
def recursive_register_module(cls, module_path, allow_create=False):
# replace whole submodule in a module
assert cls._is_module(
PATCH_ROOT + module_path
), f'"{PATCH_ROOT}{module_path}" is not a valid module path. Only use this function to register module patches (not functions or classes).'
all_submodules = cls.enumerate_submodules(PATCH_ROOT + module_path)
all_submodules = [submodule[len(PATCH_ROOT) :] for submodule in all_submodules]
all_submodules = [module_path] + all_submodules
for module in all_submodules:
try:
importlib.import_module(module)
except ModuleNotFoundError:
pass
sys.modules[module] = importlib.import_module(PATCH_ROOT + module)
@classmethod
def batch_recursive_register_module(cls, module_paths, allow_create=False):
for module_path in module_paths:
cls.recursive_register_module(module_path, allow_create=allow_create)
@classmethod
def enumerate_submodules(cls, module_name):
try:
module = importlib.import_module(module_name)
except ImportError as e:
print(f"Error importing module {module_name}: {e}")
return []
submodules = []
for loader, submodule_name, is_pkg in pkgutil.walk_packages(
module.__path__, module.__name__ + "."
):
submodules.append(submodule_name)
return submodules
@classmethod
def batch_register_patch(
cls,
orig_func_names: List,
force_patch=False,
create_dummy=False,
allow_create=False,
):
"""
This function assumes all new_func are organized in same path like orig_func_names except prefixed with 'vastext.'
"""
for orig_func_name in orig_func_names:
if not cls._path_valid(orig_func_name) and not allow_create:
print(f"WARNING: path '{orig_func_name}' is not valid, skipped.")
continue
wrapper_name = orig_func_name
if cls._path_valid(PATCH_ROOT + wrapper_name + "_wrapper"):
wrapper_name = wrapper_name + "_wrapper"
assert cls._path_valid(
PATCH_ROOT + wrapper_name
), f"'{PATCH_ROOT}{wrapper_name}' or '{PATCH_ROOT}{wrapper_name}_wrapper' must be a valid module/function path. Try import {PATCH_ROOT}{wrapper_name} to see if other errors exist."
new_func = Patch.parse_function(
PATCH_ROOT + wrapper_name, create_dummy=False
)
if new_func is None:
new_func = importlib.import_module(PATCH_ROOT + wrapper_name)
# print(f">>> Register patch or function: '{orig_func_name}' -> '{new_func}'")
cls.register_patch(
orig_func_name,
new_func,
force_patch=force_patch,
create_dummy=create_dummy,
allow_create=allow_create,
)
@classmethod
def _path_valid(cls, path):
components = path.split(".")
for i in range(len(components), 0, -1):
module_name = ".".join(components[:i])
try:
module = importlib.import_module(module_name)
break
except ImportError:
continue
else:
return False
for component in components[i:]:
if hasattr(module, component):
module = getattr(module, component)
else:
return False
return True
@classmethod
def apply_patches(cls):
patch_info = cls.get_patch_info()
for patch in patch_info.values():
patch.apply_patch()
cls.patched = True
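A minimal usage sketch of the PatchManager defined above, assuming only this file. The target (json.dumps) and the wrapper are made-up illustrations; the real plugin registers vLLM internals from vllm_patch instead.

import functools
import json

from vllm_vacc.patch_util import PatchManager

def dumps_wrapper(orig_dumps):
    # The name ends with "wrapper", so set_patch_func() stores it as a
    # decorator around the original function rather than a replacement.
    @functools.wraps(orig_dumps)
    def wrapper(obj, **kwargs):
        kwargs.setdefault("sort_keys", True)
        return orig_dumps(obj, **kwargs)
    return wrapper

PatchManager.register_patch("json.dumps", dumps_wrapper)
PatchManager.apply_patches()                # re-binds json.dumps in loaded modules
print(json.dumps({"b": 1, "a": 2}))         # keys now sorted: {"a": 2, "b": 1}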

vllm_vacc/platform.py (Normal file, 163 lines)

@@ -0,0 +1,163 @@
import os
from typing import TYPE_CHECKING, Optional
import torch
from vllm.logger import init_logger
# from .interface import Platform, PlatformEnum, _Backend
from vllm.platforms.interface import Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
from vllm.config import VllmConfig, ModelConfig
else:
VllmConfig = None
ModelConfig = None
logger = init_logger(__name__)
class VaccPlatform(Platform):
try:
import torch_vacc
is_vacc = True
except Exception as e:
raise ImportError(f"failed to import torch_vacc: {e}") from e
_enum = PlatformEnum.OOT
device_name: str = "vacc"
device_type: str = "vacc"
dispatch_key: str = "PrivateUse1"
ray_device_key: str = "GPU"
device_control_env_var: str = "VACC_VISIBLE_MODULES"
simple_compile_backend: str = "eager" # Disable torch.compile()
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool,
use_mla: bool, has_sink: bool, use_sparse: bool) -> str:
if use_mla:
logger.info("Using VACCMLA backend.")
if use_v1:
return "vllm_vacc.vllm.v1.attention.backends.vacc_mla.VACCMLABackend"
return "vllm_vacc.vllm.attention.backends.vacc_mla.VACCMLABackend"
if use_v1:
return "vllm_vacc.vllm.v1.attention.backends.vacc_attn.VACCAttentionBackend"
else:
logger.info("Using VACCAttention backend.")
return "vllm_vacc.vllm.attention.backends.vacc_attn.VACCAttentionBackend"
@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return True
@staticmethod
def inference_mode():
return torch.no_grad()
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
import vllm.envs as envs
if vllm_config.kv_transfer_config:
raise NotImplementedError("kv-transfer-config is not implemented for VACC")
cache_config = vllm_config.cache_config
scheduler_config = vllm_config.scheduler_config
if ((scheduler_config.chunked_prefill_enabled
or cache_config.enable_prefix_caching)
and cache_config.cache_dtype != "auto"):
raise RuntimeError("Chunked-prefill and prefix-cache on the Vacc "
"backend is not compatible with FP8 KV cache.")
# scheduling_polity = scheduler_config.policy
# model_config = vllm_config.model_config
# use_async_output_proc = model_config.use_async_output_proc
# if scheduling_polity == "priority" and use_async_output_proc: # probably a bug
# logger.warning("WARNING scheduling_polity priority is not fully supported for VACC, "
# "use fcfs instead automatically")
# vllm_config.scheduler_config.scheduling_polity = "fcfs"
# if vllm_config.speculative_config is not None:
# raise NotImplementedError(
# "Speculative decoding is not implemented for VACC")
parallel_config = vllm_config.parallel_config
if parallel_config.worker_cls == "auto":
if vllm_config.speculative_config:
if envs.VLLM_USE_V1:
parallel_config.worker_cls = "vllm_vacc.vllm.v1.worker.vacc_worker.VACCWorker"
else:
parallel_config.worker_cls = "vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config.sd_worker_cls = "vllm_vacc.vllm.worker.vacc_worker.VACCWorker"
else:
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm_vacc.vllm.v1.worker.vacc_worker.VACCWorker"
print('v1 VACCWorker')
else:
parallel_config.worker_cls = \
"vllm_vacc.vllm.worker.vacc_worker.VACCWorker"
# NOTE(kzawora): default block size for Gaudi should be 128
# smaller sizes still work, but very inefficiently
cache_config = vllm_config.cache_config
if cache_config and cache_config.gpu_memory_utilization:
logger.warning("WARNING gpu_memory_utilization is not supported on VACC")
# if cache_config and cache_config.enable_prefix_caching:
# raise NotImplementedError("Prefix-caching is not implemented for VACC")
if cache_config and cache_config.block_size is None:
cache_config.block_size = 16
if (parallel_config.distributed_executor_backend == 'mp'
and envs.VLLM_WORKER_MULTIPROC_METHOD == 'fork'):
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD",
None) is not None:
logger.warning("On VACC, VLLM_WORKER_MULTIPROC_METHOD=fork "
"might cause application hangs on exit. Using "
"VLLM_WORKER_MULTIPROC_METHOD=fork anyway, "
"as it was explicitly requested.")
else:
logger.warning(
"On VACC, VLLM_WORKER_MULTIPROC_METHOD=fork "
"might cause application hangs on exit. Setting "
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"To override that behavior, please set "
"VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@classmethod
def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on VACC.")
return False
@classmethod
def get_punica_wrapper(cls) -> str:
return "vllm_vacc.vllm.lora.punica_wrapper.punica_vacc.PunicaWrapperVACC"
@classmethod
def get_current_memory_usage(cls,
device: Optional[torch.types.Device] = None
) -> float:
torch.vacc.reset_peak_memory_stats(device)
return torch.vacc.max_memory_allocated(device)
@classmethod
def use_all_gather(cls) -> bool:
"""
Whether to use allgather in LogitsProcessor to gather the logits.
"""
return True
@classmethod
def supports_v1(cls, model_config: ModelConfig) -> bool:
"""Returns whether the current platform can support v1 for the supplied
model configuration.
"""
# return False # or export VLLM_USE_V1=0 to use v0
if os.getenv("VLLM_USE_V1", 1) == '0':
return False
return True
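For context, register() in __init__.py returns the dotted path to this class. Below is a small sketch of how such a path resolves to the platform class; the resolver helper is illustrative only, since the actual discovery is done by vLLM's plugin loader.

import importlib

def load_platform(dotted_path: str = "vllm_vacc.platform.VaccPlatform"):
    # Split "package.module.ClassName" into module path and attribute name,
    # import the module, and return the class object.
    module_name, class_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)

# platform_cls = load_platform()
# print(platform_cls.device_name)  # "vacc"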

(binary files not shown)

(file name not shown)

@@ -0,0 +1,95 @@
import contextlib
import importlib
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
import torch.library
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType
def cutlass_scaled_mm_vacc(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
`cutlass_scaled_mm` implements a fused version of
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
where scale_a * a and scale_b * b are implemented using numpy-style
broadcasting.
In order to support blockwise scaling like found in DeepSeek V3 we also
support extended "group" broadcast rules. We extend the numpy-style
broadcasting rules with the following rule:
"if the extent of a dimension in the source shape is between 1 and
corresponding extent in the target shape we repeat each element along
that dimension src_shape[dim] // target_shape[dim] times consecutively"
example if we have:
a = [[1, 2], and target_shape = (2, 4)
[3, 4]]
then we would expand a to:
a = [[1, 1, 2, 2],
[3, 3, 4, 4]]
currently we only support the case:
scale_a.shape * [1, 128] == a.shape
scale_b.shape * [128, 128] == b.shape
"""
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
assert bias is None or bias.shape[0] == b.shape[
1] and bias.dtype == out_dtype
m = a.shape[0]
n = b.shape[1]
if current_platform.is_rocm():
triton_scaled_mm_module = importlib.import_module(
"vllm.model_executor.layers.quantization.compressed_tensors."
"triton_scaled_mm")
triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
# print('a',a.shape,a.dtype) # torch.Size([8192, 3584]) torch.float8_e4m3fn
# print('scale_a',scale_a.shape) #torch.Size([8192, 56])
# print('b',b.shape,b.dtype) # torch.Size([3584, 1536]) torch.float8_e4m3fn
# print('scale_b',scale_b.shape) #torch.Size([56, 12])
use_a32_w32 = True  # dequantize to fp32 and compute the matmul in full precision
if use_a32_w32 or (b.shape[1]//scale_b.shape[1] != 128 or
a.shape[1]//scale_a.shape[1] != 128 or
b.shape[0]//scale_b.shape[0] != 128):
# cutlass_scaled_mm does not support quantization block sizes other than 128
a1 = a.to(torch.float32).reshape(a.shape[0], scale_a.shape[1], -1)
scale_a = scale_a.reshape(scale_a.shape[0], scale_a.shape[1], 1).to(torch.float32)
a = (a1*scale_a).reshape(a.shape).contiguous()
b1 = b.to(torch.float32).reshape(scale_b.shape[0], b.shape[0]//scale_b.shape[0], scale_b.shape[1], b.shape[1]//scale_b.shape[1])
scale_b = scale_b.reshape(scale_b.shape[0], 1, scale_b.shape[1], 1).to(torch.float32)
b = (b1*scale_b).reshape(b.shape).contiguous()
out = a@b
if bias is not None:
out = out + bias
return out.to(out_dtype)
torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
return out
def concat_and_cache_mla(
kv_c: torch.Tensor,
k_pe: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
scale: torch.Tensor,
) -> None:
torch.vacc.concat_and_cache_attention(
kv_c, k_pe, kv_cache, slot_mapping)
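Below is a tiny, self-contained illustration of the dequantize-then-matmul fallback in cutlass_scaled_mm_vacc, shrinking the 128-element quantization blocks to 2 so the scale broadcast can be verified by hand. The shapes and values are made up.

import torch

a = torch.arange(8, dtype=torch.float32).reshape(2, 4)   # pretend fp8 activations
scale_a = torch.tensor([[0.5, 2.0], [1.0, 0.25]])        # one scale per block of 2 columns

# Reshape to (rows, num_blocks, block_size), broadcast the per-block scales,
# then flatten back, exactly as the fallback path does with block size 128.
a_deq = (a.reshape(2, 2, 2) * scale_a.unsqueeze(-1)).reshape(2, 4)
print(a_deq)
# row 0: [0*0.5, 1*0.5, 2*2.0, 3*2.0] -> [0.0, 0.5, 4.0, 6.0]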

(file name not shown)

@@ -0,0 +1,202 @@
import functools
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
Type, TypeVar)
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionState, MLAAttentionImpl)
from vllm.attention.backends.mla.common import MLACommonMetadata, triton_attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
UnquantizedLinearMethod)
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm.platforms import current_platform
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
is_vllm_fa = True
except ImportError:
is_vllm_fa = False
try:
# For rocm use upstream flash attention
from vllm.attention.backends.flash_attn import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
T = TypeVar("T", bound="MLACommonMetadata")
class MLACommonImpl():
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
# blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
q_lora_rank: Optional[int],
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
rotary_emb: RotaryEmbedding,
# q_proj should be q_b_proj if q_lora_rank is not None, but from an
# attention backend perspective we rely on the layer to pass in the
# correct matrix
q_proj: ColumnParallelLinear,
kv_b_proj: ColumnParallelLinear,
o_proj: RowParallelLinear,
positions: Optional[torch.Tensor] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim
self.rotary_emb = rotary_emb
self.use_yarn_rope = isinstance(rotary_emb,
DeepseekScalingRotaryEmbedding)
self.q_proj = q_proj
self.kv_b_proj = kv_b_proj
self.o_proj = o_proj
self.positions = positions
self.triton_fa_func = triton_attention
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on RoCM and the
# latter has an additional parameter to control FA2 vs FA3
self.flash_attn_varlen_func = flash_attn_varlen_func
self.vllm_flash_attn_version = get_flash_attn_version()
if self.vllm_flash_attn_version is not None:
self.flash_attn_varlen_func = \
functools.partial(flash_attn_varlen_func,
fa_version=self.vllm_flash_attn_version)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self._pad_v = self.vllm_flash_attn_version is None or not (
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9)
def forward(
self,
layer: AttentionLayer,
hidden_states_or_q_c: torch.Tensor, # query in unified attn
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: T,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if output is not None:
raise NotImplementedError(
"output is not yet supported for MLAImplBase")
# if attn_metadata.is_profile_run and \
# attn_metadata.context_chunk_workspace is not None:
# # During the profile run try to simulate to worse case output size
# # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
# # since this can be large
# _ = torch.empty(
# (attn_metadata.context_chunk_workspace.shape[0],
# self.num_heads, self.qk_nope_head_dim + self.v_head_dim),
# device=k_c_normed.device,
# dtype=k_c_normed.dtype,
# )
has_decode = attn_metadata.decode_metadata is not None
has_prefill = attn_metadata.prefill_metadata is not None
# Restore head dim (for rotary embedding)
k_pe = k_pe.unsqueeze(1)
# assert hasattr(attn_metadata, "input_positions")
if self.positions is not None:
positions = self.positions
elif hasattr(attn_metadata, "input_positions"):
positions = attn_metadata.input_positions
else:
raise ValueError("positions not provided and attn_metadata has no input_positions")
num_prefill_tokens: int = attn_metadata.num_prefill_tokens
decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:]
decode_k_pe = k_pe[num_prefill_tokens:]
decode_input_positions = \
positions[num_prefill_tokens:]
prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens]
prefill_k_pe = k_pe[:num_prefill_tokens]
prefill_input_positions = \
positions[:num_prefill_tokens]
prefill_k_c_normed = k_c_normed[:num_prefill_tokens]
if has_decode:
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
decode_input_positions, decode_q_pe, decode_k_pe)
if has_prefill:
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
.view(-1, self.num_heads, self.qk_head_dim)
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
prefill_input_positions, prefill_q_pe, prefill_k_pe)
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
k_c_normed,
k_pe.squeeze(1),
kv_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype=self.kv_cache_dtype,
scale=layer._k_scale,
)
# output = torch.empty(attn_metadata.num_prefill_tokens +
# attn_metadata.num_decode_tokens,
# self.o_proj.output_size,
# device=hidden_states_or_q_c.device,
# dtype=hidden_states_or_q_c.dtype)
if has_prefill:
return self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata)
if has_decode:
return self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
assert False, "mla forward need prefill or decode function"
return None
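A toy illustration of the token layout that forward() assumes: the flattened batch places all prefill tokens first and the decode tokens (one per running sequence) after them, which is what the slicing on num_prefill_tokens relies on. The counts below are made up.

import torch

num_prefill_tokens = 5                      # e.g. two prompts of length 3 and 2
num_decode_tokens = 3                       # three sequences decoding one token each
hidden = torch.arange(num_prefill_tokens + num_decode_tokens)

prefill_part = hidden[:num_prefill_tokens]  # rows fed to the prefill path
decode_part = hidden[num_prefill_tokens:]   # rows fed to the decode path
assert prefill_part.numel() == num_prefill_tokens
assert decode_part.numel() == num_decode_tokens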

(file name not shown)

@@ -0,0 +1,390 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional, Tuple
import torch
from compressed_tensors.quantization import QuantizationStrategy
from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionLayer,
AttentionMetadata,
MLAAttentionImpl, T)
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Fp8)
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
# from vllm.model_executor.layers.quantization.utils.fp8_utils import (
# apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_dequantize, scaled_quantize)
import os
W_Q_W_QR_WUV_WUK_USE_FP8 = True
class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
def process_weights_after_loading(self, act_dtype: torch.dtype):
def is_layer_fp8(layer: LinearBase) -> bool:
return isinstance(layer.quant_method, Fp8LinearMethod) or\
(isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
def quantization_scheme_supported(layer: LinearBase) -> bool:
return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
is_layer_fp8(layer)
# TODO(lucas) This is very gross, we need a more wide scale refactor of
# all the FP8 code with a more standard way of
# defining schemes/group-shapes, we should also potentially force
# quant_methods to support a decompress function
#
# returns input_group_shape, weight_group_shape
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
Tuple[Tuple[int, int], Tuple[int, int]]:
if isinstance(layer.quant_method, Fp8LinearMethod):
if layer.quant_method.block_quant is not None:
weight_block_size = \
layer.quant_method.quant_config.weight_block_size
# per-token-group (1, X), block-quantized (X, Y)
return (1, weight_block_size[-1]), weight_block_size
else:
return (-1, -1), (-1, -1) # per-tensor, per-tensor
elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
# this is hacky but we always assume that for
# CompressedTensorsW8A8Fp8 the input is dynamic per-token;
# we ignore if it is static per-tensor since we are going to
# requantize later anyway
strategy = layer.scheme.strategy
if strategy == QuantizationStrategy.TENSOR:
return (1, -1), (-1, -1) # per-token, per-tensor
elif strategy == QuantizationStrategy.CHANNEL:
return (1, -1), (-1, 1) # per-token, per-channel
else:
raise NotImplementedError(
f"QuantizationStrategy.{strategy} is not supported for "
"fp8 MLA, please run with VLLM_MLA_DISABLE=1")
else:
raise NotImplementedError(
"Can't determine scale group shapes for "
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
)
def get_scales(layer: LinearBase) -> torch.Tensor:
if hasattr(layer, "weight_scale_inv"):
return layer.weight_scale_inv
return layer.weight_scale
def get_fp8_layer_weight(layer: LinearBase):
if is_layer_fp8(layer):
if isinstance(layer.quant_method, \
CompressedTensorsLinearMethod) and \
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
# NOTE(lucas): not sure why but `CompressedTensorsW8A8Fp8`
# seems to store weights as (input, output) instead of
# (output, input) so we need to transpose
weight = layer.weight.T # standardize to (output, input)
else:
weight = layer.weight
_, weight_scale_group_shape = \
get_scale_group_shapes_for_fp8(layer)
scales = get_scales(layer)  # already expanded
weight_scale_group_shape = weight_scale_group_shape.copy()  # [128, 128] read from the config; .copy() is required, otherwise the config itself would be modified
# recalibrate weight_scale_group_shape against the actual weight/scale shapes
if weight.shape[0] // scales.shape[0] != weight_scale_group_shape[0]:
weight_scale_group_shape[0] = weight.shape[0] // scales.shape[0]
if weight.shape[1] // scales.shape[1] != weight_scale_group_shape[1]:
weight_scale_group_shape[1] = weight.shape[1] // scales.shape[1]
return weight, scales
else:
return layer.weight, None
def get_fp8_layer_weight_test(layer: LinearBase):
if is_layer_fp8(layer):
if isinstance(layer.quant_method, \
CompressedTensorsLinearMethod) and \
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
# NOTE(lucas): not sure why but `CompressedTensorsW8A8Fp8`
# seems to store weights as (input, output) instead of
# (output, input) so we need to transpose
weight = layer.weight.T # standardize to (output, input)
else:
weight = layer.weight
_, weight_scale_group_shape = \
get_scale_group_shapes_for_fp8(layer)
scales = get_scales(layer)  # already expanded
weight_scale_group_shape = weight_scale_group_shape.copy()  # [128, 128] read from the config; .copy() is required, otherwise the config itself would be modified
# recalibrate weight_scale_group_shape against the actual weight/scale shapes
if weight.shape[0] // scales.shape[0] != weight_scale_group_shape[0]:
weight_scale_group_shape[0] = weight.shape[0] // scales.shape[0]
if weight.shape[1] // scales.shape[1] != weight_scale_group_shape[1]:
weight_scale_group_shape[1] = weight.shape[1] // scales.shape[1]
# for test
weight = scaled_dequantize(weight, scales, weight_scale_group_shape)
# print(f'{weight.shape}, {scales.shape}, {weight_scale_group_shape}')
return weight, scales
else:
return layer.weight, None
def check_eq(name, tensor0, tensor1):
assert tensor0.shape == tensor1.shape
isEqual = torch.equal(tensor0.reshape([-1]).float(), tensor1.reshape([-1]).float())
print(f"{os.getpid()} check {name} {tensor0.shape} equal: {isEqual}")
return isEqual
def get_and_maybe_dequant_weights(layer: LinearBase):
if is_layer_fp8(layer):
if isinstance(layer.quant_method, \
CompressedTensorsLinearMethod) and \
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
# NOTE(lucas): not sure why but `CompressedTensorsW8A8Fp8`
# seems to store weights as (input, output) instead of
# (output, input) so we need to transpose
weight = layer.weight.T # standardize to (output, input)
else:
weight = layer.weight
_, weight_scale_group_shape = \
get_scale_group_shapes_for_fp8(layer)
scales = get_scales(layer)  # already expanded
weight_scale_group_shape = weight_scale_group_shape.copy()  # [128, 128] read from the config; .copy() is required, otherwise the config itself would be modified
# recalibrate weight_scale_group_shape against the actual weight/scale shapes
if weight.shape[0] // scales.shape[0] != weight_scale_group_shape[0]:
weight_scale_group_shape[0] = weight.shape[0] // scales.shape[0]
if weight.shape[1] // scales.shape[1] != weight_scale_group_shape[1]:
weight_scale_group_shape[1] = weight.shape[1] // scales.shape[1]
return scaled_dequantize(weight, scales,
weight_scale_group_shape)
else:
return layer.weight
if not (quantization_scheme_supported(self.kv_b_proj) and\
quantization_scheme_supported(self.q_proj) and\
quantization_scheme_supported(self.o_proj)):
raise NotImplementedError(
"Only FP8 and UnquantizedLinearMethod are supported for MLA"
", please run with VLLM_MLA_DISABLE=1")
weight_dtype = self.kv_b_proj.weight.dtype
assert self.o_proj.weight.dtype == weight_dtype
assert self.q_proj.weight.dtype == weight_dtype
if W_Q_W_QR_WUV_WUK_USE_FP8: #and not envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
# 512,1024(=4x256)
kv_b_proj_weight, kv_b_proj_scale = \
[t.T for t in get_fp8_layer_weight(self.kv_b_proj)]
# kv_b_proj_weight = kv_b_proj_weight.transpose(-1,-2).contiguous().transpose(-1,-2)
N, K = kv_b_proj_weight.shape[0], kv_b_proj_weight.shape[1]
# 512,1024 => 512,4,256
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.v_head_dim,
)
kv_b_proj_scale = kv_b_proj_scale.view(
kv_b_proj_scale.shape[0] * self.kv_lora_rank // N,
self.num_heads,
kv_b_proj_scale.shape[1] * N // (self.kv_lora_rank * self.num_heads),
)
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
W_UK = W_UK.contiguous()
scale_0 = kv_b_proj_scale.shape[-1] * self.qk_nope_head_dim // (self.qk_nope_head_dim + self.v_head_dim)
scale_1 = kv_b_proj_scale.shape[-1] - scale_0
W_UK_scale, W_UV_scale = kv_b_proj_scale.split(
[scale_0, scale_1], dim=-1)
W_UK_scale = W_UK_scale.view(W_UK_scale.shape[0], -1).unsqueeze(-1).contiguous()
W_UV_scale = W_UV_scale.view(W_UV_scale.shape[0], -1).unsqueeze(-1)
# weight: [1536, 768] scale: 12,6
q_proj_weight, q_proj_scale = \
[t.T for t in get_fp8_layer_weight(self.q_proj)]
#self.W_Q_QR = q_proj_weight.contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
#self.W_Q_QR_scales = q_proj_scale.reshape(12, 6, 1).repeat(1, 1, 4).reshape(12, -1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
q_proj_weight = q_proj_weight\
.view(-1, self.num_heads, self.qk_head_dim)
# w_q[1536, 512] + w_qr[1536, 256]
W_Q = q_proj_weight[..., :self.qk_nope_head_dim].flatten(start_dim=1)
W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
.flatten(start_dim=1).contiguous()
# w_q_scale 12,16 + w_qr_scale 12,8
# expand: 12,6(4+2) -> 12,24(16+8)
# Q_scale: [s0x4, s1x2, s2x2, s3x4, s4x2, s5x2]
repeat_pattern = torch.tensor([4, 2, 2, 4, 2, 2], device=q_proj_scale.device)
W_Q_scale = torch.repeat_interleave(q_proj_scale, repeat_pattern, dim=1)
# Q_R_scale: [s1x2, s2x2, s4x2, s5x2]
selected_indices = [1, 2, 4, 5]
repeat_times = 2
selected = q_proj_scale[:, selected_indices]
W_QR_scale = selected.repeat_interleave(repeat_times, dim=1)
# temp_WQ_Scale = W_Q_scale.reshape(12, 4, -1).contiguous()
# temp_W_QR_scale = W_QR_scale.reshape(12, 4, -1).contiguous()
# temp_scale = torch.cat([temp_WQ_Scale, temp_W_QR_scale], dim=2).contiguous().reshape(12, -1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
# self.W_Q_QR_scales = temp_scale
# print("W_Q_scale:", W_Q_scale.shape)
# print("W_QR_scale:", W_QR_scale.shape)
# print("temp_scale:", temp_scale.shape)
# exit(0)
# Note: to be vnnl compatible
# 1. expand w_uv scale for core split friendly
if W_UV.shape[-1] % 4 == 0:
W_UV_scale = W_UV_scale.expand((W_UV_scale.shape[0], W_UV_scale.shape[1], 4))
# 2. change w_q, w_qr, w_uv weight&scale to K-contiguous (shape unchanged)
W_Q = W_Q.transpose(-2,-1).contiguous().transpose(-2,-1)
W_Q_scale = W_Q_scale.transpose(-2,-1).contiguous().transpose(-2,-1)
W_QR = W_QR.transpose(-2,-1).contiguous().transpose(-2,-1)
W_QR_scale = W_QR_scale.transpose(-2,-1).contiguous().transpose(-2,-1)
W_UV = W_UV.permute(2,1,0).contiguous().permute(2,1,0)
W_UV_scale = W_UV_scale.permute(2,1,0).contiguous().permute(2,1,0)
self.W_Q = W_Q
self.W_Q_scales = W_Q_scale
self.W_QR = W_QR
self.W_QR_scales = W_QR_scale
# temp_Q_scale = self.W_Q_scales.contiguous()
# temp_W_QR_scale = self.W_QR_scales.contiguous()
# self.W_Q_QR = q_proj_weight.reshape(1536, -1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
# self.W_Q_QR_scales = torch.concat([temp_Q_scale,temp_W_QR_scale],dim=1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
#self.W_Q_QR = torch.concat([self.W_Q.contiguous(),self.W_QR.contiguous()],dim=1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
#self.W_Q_QR_scales = torch.concat([W_Q_scale,W_QR_scale],dim=1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
self.W_UV = W_UV
self.W_UV_scales = W_UV_scale
self.W_UK = W_UK
self.W_UK_scales = W_UK_scale
return
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}")
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.v_head_dim,
)
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
.view(-1, self.num_heads, self.qk_head_dim)
# can be W_Q or W_UQ depending on q_lora_rank, the former if
# q_lora_rank is None, the latter otherwise. From the Attention backend
# perspective though we call these both W_Q and rely on the layer
# to pass in the correct matrix
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
.flatten(start_dim=1).contiguous()
# W_QR is small so for simplicity we don't bother requantizing it
self.W_QR = self.W_QR.to(act_dtype)
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
assert False, "please set VLLM_MLA_PERFORM_MATRIX_ABSORPTION=0"
requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
if is_fp8(weight_dtype) and requantization_enabled:
# This assumes it is wise to requantize using the same group shapes
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
# weights were originally quantized
requant_input_group_shape, requant_weight_group_shape = \
get_scale_group_shapes_for_fp8(self.q_proj)
assert (requant_input_group_shape, requant_weight_group_shape)\
== get_scale_group_shapes_for_fp8(self.kv_b_proj)
assert (requant_input_group_shape, requant_weight_group_shape)\
== get_scale_group_shapes_for_fp8(self.o_proj)
self.reqaunt_input_group_shape = requant_input_group_shape
self.reqaunt_weight_group_shape = requant_weight_group_shape
#
# Perform matrix-absorption following
# https://github.com/flashinfer-ai/flashinfer/pull/551
# for decode, as a result we end up with absorbed weights for decode
# and another copy of raw weights for prefill.
#
self.W_UK, self.W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
# We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
# depending on q_lora_rank, the former if q_lora_rank is None, the
# latter otherwise
# basically if q_lora_rank is none we are absorbing into q_proj
# instead of UQ
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
.flatten(start_dim=1).contiguous()
if is_fp8(weight_dtype) and requantization_enabled:
W_Q_UK, W_Q_UK_scales = scaled_quantize(
W_Q_UK,
self.reqaunt_weight_group_shape,
quant_dtype=current_platform_fp8_dtype)
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_Q_UK = W_Q_UK.T.contiguous()
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
else:
self.W_Q_UK = W_Q_UK.to(act_dtype)
W_O = get_and_maybe_dequant_weights(self.o_proj)\
.view(-1, self.num_heads, self.v_head_dim)
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
.flatten(start_dim=0, end_dim=1).contiguous()
if is_fp8(weight_dtype) and requantization_enabled:
W_UV_O, W_UV_O_scales = scaled_quantize(
W_UV_O,
self.reqaunt_weight_group_shape,
quant_dtype=current_platform_fp8_dtype)
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_UV_O = W_UV_O.T.contiguous()
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
else:
self.W_UV_O = W_UV_O.to(act_dtype)
self.tp_size = get_tensor_model_parallel_world_size()
else:
# print('W_UV', W_UV.dtype) #float32
#if is_fp8(weight_dtype):
# raise NotImplementedError(
# "Currently fp8 requires matrix absorption")
# self.W_UV = W_UV
# self.W_UK = W_UK
self.W_UV = W_UV.to(act_dtype) # fp32 to bfp16
self.W_UK = W_UK.to(act_dtype)
W_Q = W_Q.to(act_dtype)
self.W_Q = W_Q.flatten(start_dim=1)
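A toy check of the scale expansion used for W_Q above: a row of six per-block scales is spread across the nope columns with the pattern [4, 2, 2, 4, 2, 2], and the rope part reuses scales 1, 2, 4, 5 twice each. The numeric values are made up; only the pattern and indices come from the code.

import torch

q_proj_scale = torch.tensor([[1., 2., 3., 4., 5., 6.]])          # made-up scales
repeat_pattern = torch.tensor([4, 2, 2, 4, 2, 2])
w_q_scale = torch.repeat_interleave(q_proj_scale, repeat_pattern, dim=1)
print(w_q_scale)   # [[1,1,1,1, 2,2, 3,3, 4,4,4,4, 5,5, 6,6]] -> 16 columns

w_qr_scale = q_proj_scale[:, [1, 2, 4, 5]].repeat_interleave(2, dim=1)
print(w_qr_scale)  # [[2,2, 3,3, 5,5, 6,6]] -> 8 columns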

(file name not shown)

@@ -0,0 +1,726 @@
""" Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
# from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm_vacc.vllm.attention.ops.vacc_paged_attn import VaccPagedAttention as PagedAttention
# from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
# AttentionLayer,
# AttentionMetadata,
# AttentionMetadataBuilder,
# AttentionType)
# from vllm.attention.backends.utils import CommonAttentionState
# from vllm.attention.ops.ipex_attn import PagedAttention
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.logger import init_logger
from vllm.utils import make_tensor_with_pad
from vllm_vacc.vllm.v1.worker.vacc_model_runner import ModelInputForVACCBuilder
logger = init_logger(__name__)
import os
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
class VACCAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "TORCH_VACC"
@staticmethod
def get_impl_cls() -> Type["VACCAttentionBackendImpl"]:
return VACCAttentionBackendImpl
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return VACCAttentionMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_builder_cls() -> Type["VACCMetadataBuilder"]:
return VACCMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class VACCAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
"""Metadata for VACCAttentionMetadata.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
chunked_prefill: bool
seq_lens: Optional[List[int]] = None # For non-chunked prefill
# For chunked prefill only
max_query_len: Optional[int] = None
max_kv_len: Optional[int] = None
query_start_loc: Optional[torch.Tensor] = None
kv_start_loc: Optional[torch.Tensor] = None
prefill_block_tables: Optional[torch.Tensor] = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[torch.Tensor]] = None
self.encoder_attn_bias: Optional[List[torch.Tensor]] = None
self.cross_attn_bias: Optional[List[torch.Tensor]] = None
@property
def is_all_encoder_attn_metadata_set(self):
'''
All attention metadata required for encoder attention is set.
'''
return ((self.encoder_seq_lens is not None)
and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None))
@property
def is_all_cross_attn_metadata_set(self):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return (self.is_all_encoder_attn_metadata_set
and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None))
@property
def prefill_metadata(self) -> Optional["VACCAttentionMetadata"]:
# Currently chunked prefill is not supported
if self.num_prefill_tokens == 0:
return None
return self
@property
def decode_metadata(self) -> Optional["VACCAttentionMetadata"]:
# Currently chunked prefill is not supported
if self.num_decode_tokens == 0:
return None
return self
def get_seq_lens(
self,
attn_type: AttentionType,
):
'''
Extract appropriate sequence lengths from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence lengths tensor for query
* Appropriate sequence lengths tensor for key & value
'''
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
seq_lens_q = self.seq_lens
seq_lens_kv = self.seq_lens
elif attn_type == AttentionType.ENCODER:
seq_lens_q = self.encoder_seq_lens
seq_lens_kv = self.encoder_seq_lens
elif attn_type == AttentionType.ENCODER_DECODER:
seq_lens_q = self.seq_lens
seq_lens_kv = self.encoder_seq_lens
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
return seq_lens_q, seq_lens_kv
def get_attn_bias(
self,
attn_type: AttentionType,
) -> Optional[List[torch.Tensor]]:
'''
Extract appropriate attention bias from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate attention bias value given the attention type
'''
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
return self.attn_bias
elif attn_type == AttentionType.ENCODER:
return self.encoder_attn_bias
elif attn_type == AttentionType.ENCODER_DECODER:
return self.cross_attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def set_attn_bias(
self,
attn_bias: List[torch.Tensor],
attn_type: AttentionType,
) -> None:
'''
Update appropriate attention bias field of attention metadata,
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_bias: The desired attention bias value
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
'''
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
self.attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER:
self.encoder_attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER_DECODER:
self.cross_attn_bias = attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def get_seq_len_block_table_args(
self,
attn_type: str,
) -> tuple:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
return (self.seq_lens_tensor, self.max_decode_seq_len,
self.block_tables)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
self.cross_block_tables)
elif attn_type == AttentionType.ENCODER:
# No block tables associated with encoder attention
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
class VACCMetadataBuilder(AttentionMetadataBuilder[VACCAttentionMetadata]):
def __init__(self, input_builder: ModelInputForVACCBuilder) -> None:
self.chunked_prefill = input_builder.chunked_prefill
self.input_builder = input_builder
def prepare(self):
self.input_data = self.input_builder.input_data
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int) -> VACCAttentionMetadata:
input_data = self.input_data
prefill_seq_lens = seq_lens[0:input_data.num_prefills]
prefill_query_lens = query_lens[0:input_data.num_prefills]
slot_mapping = torch.tensor(input_data.slot_mapping,
dtype=torch.int32,
device=self.input_builder.device)
# For chunked-prefill
if self.chunked_prefill and input_data.num_prefill_tokens != 0:
prefill_block_tables = make_tensor_with_pad(
self.input_data.prefill_block_tables,
pad=0,
dtype=torch.int32,
device=self.input_builder.device,
)
query_lens_tensor = torch.tensor(prefill_query_lens,
dtype=torch.int32,
device=self.input_builder.device)
kv_lens_tensor = torch.tensor(prefill_seq_lens,
dtype=torch.int32,
device=self.input_builder.device)
query_start_loc = torch.zeros(input_data.num_prefills + 1,
dtype=torch.int32,
device=self.input_builder.device)
kv_start_loc = torch.zeros(input_data.num_prefills + 1,
dtype=torch.int32,
device=self.input_builder.device)
torch.cumsum(query_lens_tensor,
dim=0,
dtype=torch.int32,
out=query_start_loc[1:])
torch.cumsum(kv_lens_tensor,
dim=0,
dtype=torch.int32,
out=kv_start_loc[1:])
max_query_len = max(prefill_query_lens)
max_kv_len = max(prefill_seq_lens)
else:
prefill_block_tables = None
query_start_loc = None
kv_start_loc = None
max_query_len = None
max_kv_len = None
# For paged attention
if input_data.num_decode_tokens != 0:
seq_lens_tensor = torch.tensor(
input_data.seq_lens[input_data.num_prefills:],
dtype=torch.int32,
device=self.input_builder.device,
)
block_tables = make_tensor_with_pad(
self.input_data.decode_block_tables,
pad=0,
dtype=torch.int32,
device=self.input_builder.device,
)
# lowest_dim_size = block_tables.size(-1)
# if lowest_dim_size < 1024:
# padding_amount = 1024 - lowest_dim_size
# padding = torch.zeros(*block_tables.size()[:-1], padding_amount, dtype=block_tables.dtype, device=block_tables.device)
# block_tables = torch.cat((block_tables, padding), dim=-1)
else:
block_tables = torch.tensor([])
seq_lens_tensor = torch.tensor(
input_data.seq_lens[:input_data.num_prefills],
dtype=torch.int32,
device=self.input_builder.device,
)
# For multi-modal models
placeholder_index_maps = None
if len(input_data.multi_modal_inputs_list) != 0:
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
input_data.multi_modal_placeholder_maps.items()
}
attn_metadata = VACCAttentionMetadata(
chunked_prefill=self.chunked_prefill,
seq_lens=seq_lens, #prefill_seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_kv_len=max_kv_len,
query_start_loc=query_start_loc,
kv_start_loc=kv_start_loc,
max_decode_seq_len=None,
num_prefills=input_data.num_prefills,
num_prefill_tokens=input_data.num_prefill_tokens,
num_decode_tokens=input_data.num_decode_tokens,
block_tables=block_tables,
prefill_block_tables=prefill_block_tables,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
)
return attn_metadata
class VACCAttentionBackendImpl(AttentionImpl[VACCAttentionMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
) -> None:
if blocksparse_params is not None:
raise ValueError(
"Torch SPDA does not support block-sparse attention.")
if logits_soft_cap is not None:
logger.warning_once("Torch SPDA does not support logits soft cap. "
"Outputs may be slightly off.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.need_mask = (self.alibi_slopes is not None
or self.sliding_window is not None)
supported_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
if kv_cache_dtype != "auto":
raise NotImplementedError(
"Torch SDPA backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")
self.attn_type = attn_type
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: VACCAttentionMetadata, # type: ignore
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
attn_type = self.attn_type
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
"encoder metadata attributes.")
elif (attn_type == AttentionType.ENCODER_DECODER
and (not attn_metadata.is_all_cross_attn_metadata_set)):
raise AttributeError("Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes.")
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
assert value is not None
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
else:
assert value is None
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
if (key is not None) and (value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
# Update cross-attention KV cache (prefill-only)
# During cross-attention decode, key & value will be None,
# preventing this IF-statement branch from running
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
updated_slot_mapping,
self.kv_cache_dtype,
layer._k_scale, layer._v_scale)
if attn_type != AttentionType.ENCODER:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
else:
# Encoder attention - chunked prefill is not applicable;
# derive token-count from query shape & and treat them
# as 100% prefill tokens
assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens
num_decode_tokens = 0
if attn_type == AttentionType.DECODER:
# Only enforce this shape-constraint for decoder
# self-attention
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
if prefill_meta := attn_metadata.prefill_metadata:
assert attn_metadata.seq_lens is not None
if (kv_cache.numel() == 0
or prefill_meta.block_tables.numel() == 0):
self._run_vacc_forward(
output,
query,
key,
value,
prefill_meta,
attn_type=attn_type)
else:
# prefix-enabled attention
assert not self.need_mask
import intel_extension_for_pytorch.llm.modules as ipex_modules
output = torch.empty_like(query)
ipex_modules.PagedAttention.flash_attn_varlen_func(
output[:prefill_meta.num_prefill_tokens, :, :],
query[:prefill_meta.num_prefill_tokens, :, :],
key_cache,
value_cache,
prefill_meta.query_start_loc,
prefill_meta.kv_start_loc,
prefill_meta.max_query_len,
prefill_meta.max_kv_len,
self.scale,
True,
prefill_meta.prefill_block_tables,
self.alibi_slopes,
)
if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have decode metadata.")
# Decoding run.
# (
# seq_lens_arg,
# max_seq_len_arg,
# block_tables_arg,
# ) = decode_meta.get_seq_len_block_table_args(attn_type)
# Note:
# decode attention still use SDPA method
# reshape k/v_cache to (num_block_grp, block_grp_size, head, hidden_size)
k_cache = key_cache.view(-1, env_blk_grp_size, key_cache.shape[2], key_cache.shape[3])
v_cache = value_cache.view(-1, env_blk_grp_size, value_cache.shape[2], value_cache.shape[3])
block_per_group = env_blk_grp_size // 16
# convert block_tables to 8K group index
block_tables = (decode_meta.block_tables // block_per_group).to(torch.int32)
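# Illustrative example (assumed values, not from this config): with
# env_blk_grp_size = 8192 and the 16-token vLLM blocks implied by the divisor
# above, block_per_group = 8192 // 16 = 512, so vLLM block id 1024 maps to
# hardware block-group index 1024 // 512 = 2.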
attn_outs = []
for i in range(decode_meta.seq_lens_tensor.shape[0]):
seq_len = decode_meta.seq_lens_tensor[i]
k_slices = k_cache[block_tables[i], ...]
k = \
torch.cat([k_slices[j, ...] for j in range(len(block_tables[i]))], dim=0)[:seq_len]
v_slices = v_cache[block_tables[i], ...]
v = \
torch.cat([v_slices[j, ...] for j in range(len(block_tables[i]))], dim=0)[:seq_len]
q = query[i : i + 1, ...]
attn_out = torch.vacc.scaled_dot_product_attention(
query=q,
key=k,
value=v,
attn_mask=None,
dropout_p=0,
is_causal=False,
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=self.scale,
)
attn_outs.append(attn_out)
output = torch.cat(attn_outs, dim=0)
# '''
# PagedAttention.forward_decode(
# output[attn_metadata.num_prefill_tokens:, :, :],
# query[attn_metadata.num_prefill_tokens:, :, :],
# key_cache,
# value_cache,
# block_tables_arg,
# seq_lens_arg,
# max_seq_len_arg,
# self.kv_cache_dtype,
# self.num_kv_heads,
# self.scale,
# self.alibi_slopes,
# layer._k_scale,
# layer._v_scale,
# )
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def _run_vacc_forward(
self,
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: VACCAttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER,
):
# if self.num_kv_heads != self.num_heads:
# key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
# value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
attn_masks = attn_metadata.get_attn_bias(attn_type)
if attn_masks is None:
if self.alibi_slopes is not None:
attn_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
attn_metadata.seq_lens) # type: ignore
elif self.sliding_window is not None:
assert attn_metadata.seq_lens is not None
attn_masks = _make_sliding_window_bias(
attn_metadata.seq_lens, self.sliding_window,
query.dtype) # type: ignore
else:
seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
attn_masks = [None] * len(seq_lens)
attn_metadata.set_attn_bias(attn_masks, attn_type)
causal_attn = (attn_type == AttentionType.DECODER)
seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
start_q, start_kv = 0, 0
for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
attn_masks):
end_q = start_q + seq_len_q
end_kv = start_kv + seq_len_kv
sub_out = torch.vacc.scaled_dot_product_attention(
query[start_q:end_q, :, :],
key[start_kv:end_kv, :, :],
value[start_kv:end_kv, :, :].contiguous(),
attn_mask=None,
dropout_p=0.0,
is_causal=True,  # causal_attn and not self.need_mask
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=self.scale)
output[start_q:end_q, :, :] = sub_out
start_q, start_kv = end_q, end_kv
return output
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
seq_lens: List[int],
) -> List[torch.Tensor]:
attn_biases: List[torch.Tensor] = []
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
num_heads = alibi_slopes.shape[0]
bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
inf_mask = torch.empty(
(1, seq_len, seq_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
attn_biases.append((bias + inf_mask).to(dtype))
return attn_biases
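# Note on the bias produced above: for each sequence, bias[i, j] = j - i is
# scaled per head by alibi_slopes, and inf_mask adds -inf strictly above the
# diagonal, so each entry in attn_biases is an additive causal ALiBi bias of
# shape (1, num_heads, seq_len, seq_len).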
def _make_sliding_window_bias(
seq_lens: List[int],
window_size: Optional[int],
dtype: torch.dtype,
) -> List[torch.Tensor]:
attn_biases: List[torch.Tensor] = []
for seq_len in seq_lens:
tensor = torch.full(
(1, seq_len, seq_len),
dtype=dtype,
fill_value=1,
)
shift = 0
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
if window_size is not None:
mask = torch.triu(mask, diagonal=shift - window_size + 1)
mask = torch.log(mask)
attn_biases.append(mask.to(dtype))
return attn_biases
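# Note on the log trick above: the mask holds 1.0 for positions inside the
# causal (optionally windowed) band and 0.0 outside, so torch.log maps them to
# 0.0 and -inf respectively, yielding an additive attention bias.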


@@ -0,0 +1,847 @@
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
from vllm.multimodal import MultiModalPlaceholderMap
try:
from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
BatchDecodeMlaWithPagedKVCacheWrapper = None
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionState, AttentionType)
from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonMetadata
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
#from vllm.attention.ops.paged_attn import PagedAttention
from vllm_vacc.vllm.attention.ops.vacc_paged_attn import VaccPagedAttention as PagedAttention
# from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
# import time, os
if TYPE_CHECKING:
from vllm_vacc.vllm.worker.vacc_model_runner import (ModelInputForVACCBuilder,
ModelInputForVACCWithSamplingMetadata)
class VACCMLABackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "TORCH_VACC"
@staticmethod
def get_impl_cls() -> Type["VACCMLAImpl"]:
return VACCMLAImpl
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return VACCMLAMetadata
@staticmethod
def get_builder_cls() -> Type["VACCMLAMetadataBuilder"]:
return VACCMLAMetadataBuilder
@staticmethod
def get_state_cls() -> Type["VACCMLAState"]:
return VACCMLAState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
head_size: int,
) -> Tuple[int, ...]:
return (num_blocks, block_size, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [576]
class VACCMLAState(AttentionState):
def __init__(self, runner):
self.runner = runner
self._is_graph_capturing = False
@contextmanager
def graph_capture(self, max_batch_size: int):
self._is_graph_capturing = True
self._graph_slot_mapping = torch.full((max_batch_size, ),
PAD_SLOT_ID,
dtype=torch.long,
device=self.runner.device)
self._graph_seq_lens = torch.ones(max_batch_size,
dtype=torch.int32,
device=self.runner.device)
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)
self._positions = torch.zeros((max_batch_size, ),
dtype=torch.long,
device=self.runner.device)
yield
self._is_graph_capturing = False
del self._graph_slot_mapping
del self._graph_seq_lens
del self._graph_block_tables
del self._positions
def graph_clone(self, batch_size: int):
assert self._is_graph_capturing
return self.__class__(self.runner)
def graph_capture_get_metadata_for_batch(
self, batch_size: int, is_encoder_decoder_model: bool = False):
assert self._is_graph_capturing
attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
# max_query_len=1,
# max_decode_query_len=1,
max_prefill_seq_len=0,
max_decode_seq_len=self.runner.max_seq_len_to_capture,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self._graph_block_tables[:batch_size],
use_cuda_graph=True,
input_positions=self._positions[:batch_size],
head_dim=self.runner.model_config.get_head_size())
if is_encoder_decoder_model:
raise NotImplementedError(
"VACCMLAState does not support encoder/decoder yet")
return attn_metadata
def get_graph_input_buffers(self,
attn_metadata,
is_encoder_decoder_model: bool = False):
input_buffers = {
"slot_mapping": attn_metadata.slot_mapping,
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
"input_positions": attn_metadata.decode_metadata.input_positions,
}
if is_encoder_decoder_model:
raise NotImplementedError(
"VACCMLAState does not support encoder/decoder yet")
return input_buffers
def prepare_graph_input_buffers(self,
input_buffers,
attn_metadata,
is_encoder_decoder_model: bool = False):
input_positions = attn_metadata.input_positions
num_positions = input_positions.shape[0]
input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
# CUDA graph buffer is padded so only perform a partial copy based on
# num_positions
input_buffers["input_positions"][:num_positions].copy_(
input_positions, non_blocking=True)
if is_encoder_decoder_model:
raise NotImplementedError(
"VACCMLAState does not support encoder/decoder yet")
def begin_forward(self, model_input):
return
@dataclass
class VACCMLAMetadata(MLACommonMetadata):
"""Metadata for VACCMLAMetadata.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is a decoding run.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# Maximum query length in the batch.
max_query_len: Optional[int] = None
input_positions: Optional[torch.Tensor] = None
# Max number of query tokens among request in the batch.
max_decode_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor] = None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
_cached_prefill_metadata: Optional["VACCMLAMetadata"] = None
_cached_decode_metadata: Optional["VACCMLAMetadata"] = None
num_prefill_tokens: int
num_kv_splits: int = 4 # TODO(lucas) add heuristic
attn_logits: Optional[torch.Tensor] = None
req_idx: Optional[torch.Tensor] = None
# The dimension of the attention heads
head_dim: Optional[int] = None
def __post_init__(self):
supported_head_sizes = VACCMLABackend.get_supported_head_sizes()
if self.head_dim is not None and self.head_dim \
not in supported_head_sizes:
raise ValueError(
f"Only {supported_head_sizes} are supported for head_dim,",
f"received {self.head_dim}.")
@property
def prefill_metadata(self) -> Optional["VACCMLAMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
# Compute some attn_metadata fields which default to None
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1])
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:self.num_prefill_tokens])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
seq_start_loc = (None if self.seq_start_loc is None else
self.seq_start_loc[:self.num_prefills + 1])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
input_positions = (None if self.input_positions is None else
self.input_positions[:self.num_prefill_tokens])
self._cached_prefill_metadata = VACCMLAMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
input_positions=input_positions,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_prefill_seq_len=None,
max_decode_seq_len=0,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
head_dim=self.head_dim)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["VACCMLAMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert self.seq_lens_tensor is not None
# Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[self.num_prefill_tokens:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
input_positions = (None if self.input_positions is None else
self.input_positions[self.num_prefill_tokens:])
self._cached_decode_metadata = VACCMLAMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=self.seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=self.max_decode_query_len,
max_query_len=self.max_query_len,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
# Batch may be composed of prefill|decodes, adjust query start
# indices to refer to the start of decodes. E.g.
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
query_start_loc=(self.query_start_loc[self.num_prefills:] -
self.query_start_loc[self.num_prefills])
if self.query_start_loc is not None else None,
seq_start_loc=self.seq_start_loc[self.num_prefills:]
if self.seq_start_loc is not None else None,
context_lens_tensor=None,
block_tables=block_tables,
use_cuda_graph=self.use_cuda_graph,
input_positions=input_positions,
head_dim=self.head_dim)
return self._cached_decode_metadata
def advance_step(self,
model_input: "ModelInputForVACCWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
# When using cudagraph, the num_seqs is padded to the next captured
# batch size, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if num_seqs != num_queries:
assert num_seqs > num_queries
assert self.use_cuda_graph
if turn_prefills_into_decodes:
# When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
# self.num_prefill_tokens = 0
# self.max_prefill_seq_len = 0
self.max_query_len = 1
self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens is not None
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
assert self.slot_mapping.shape == (num_seqs, )
assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
assert self.seq_lens_tensor.shape == (num_seqs, )
# assert self.max_query_len == 1
# assert self.max_prefill_seq_len == 0
assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
assert self.seq_start_loc is not None
assert self.seq_start_loc.shape == (num_seqs + 1, )
assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, )
assert self.block_tables is not None
assert self.block_tables.shape[0] == num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
self.seq_lens[i] += 1
# self.max_decode_seq_len = None
ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables)
class VACCMLAMetadataBuilder(AttentionMetadataBuilder[VACCMLAMetadata]):
def __init__(self, input_builder: "ModelInputForVACCBuilder"):
self.chunked_prefill = True
if hasattr(input_builder, 'chunked_prefill'):
self.chunked_prefill = input_builder.chunked_prefill
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
def prepare(self):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.input_positions: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
self.has_prefix_cache_hit = False
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
self.input_data = self.input_builder.input_data
self.slot_mapping = self.input_data.slot_mapping
self.context_lens = self.input_data.context_lens
if self.input_data.num_prefill_tokens != 0:
self.block_tables = self.input_data.prefill_block_tables
else:
self.block_tables = self.input_data.decode_block_tables
self.input_positions = self.input_data.input_positions
self.prefill_seq_lens = seq_lens[0:self.input_data.num_prefills]
self.num_prefills = self.input_data.num_prefills
self.num_prefill_tokens = self.input_data.num_prefill_tokens
self.num_decode_tokens = self.input_data.num_decode_tokens
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
# max_query_len = max(query_lens)
# decode_query_lens = query_lens[self.num_prefills:]
# if len(decode_query_lens) > 0:
# max_decode_query_len = max(decode_query_lens)
# else:
# max_decode_query_len = 1
# max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
# max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
query_start_loc = list(accumulate(query_lens, initial=0))
seq_start_loc = list(accumulate(seq_lens, initial=0))
num_seqs = len(seq_lens)
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size - self.num_prefill_tokens
block_tables = self._get_graph_runner_block_tables(
num_seqs, self.block_tables)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
dtype=torch.int,
device=device,
)
# assert max_query_len > 0, ("query_lens: {}".format(query_lens))
assert device is not None
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
input_positions = async_tensor_h2d(self.input_positions, torch.int,
device, self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int,
device, self.runner.pin_memory)
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
device,
self.runner.pin_memory)
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
device, self.runner.pin_memory)
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
self.multimodal_placeholder_maps.items()
}
return VACCMLAMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
input_positions=input_positions,
seq_lens_tensor=seq_lens_tensor,
# max_query_len=max_query_len,
# max_decode_query_len=None,
max_prefill_seq_len=None,
max_decode_seq_len=None,
query_start_loc=query_start_loc_tensor,
seq_start_loc=seq_start_loc_tensor,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
num_kv_splits=4, # TODO(lucas) add heuristic
head_dim=self.runner.model_config.get_head_size(),
)
class VACCMLAImpl(MLACommonImpl[VACCMLAMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**kwargs) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **kwargs)
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"VACCMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"VACCMLAImpl")
def extract_weights(self):
weights = {}
if hasattr(self, 'W_Q'):
weights["W_Q"] = self.W_Q
if hasattr(self, 'W_Q_scales'):
weights["W_Q_scales"] = self.W_Q_scales
if hasattr(self, 'W_QR'):
weights['W_QR'] = self.W_QR
if hasattr(self, 'W_QR_scales'):
weights["W_QR_scales"] = self.W_QR_scales
if hasattr(self, 'W_Q_QR'):
weights["W_Q_QR"] = self.W_Q_QR
if hasattr(self, 'W_Q_QR_scales'):
weights["W_Q_QR_scales"] = self.W_Q_QR_scales
if hasattr(self, 'W_UK'):
weights['W_UK'] = self.W_UK
if hasattr(self, 'W_UK_scales'):
weights['W_UK_scales'] = self.W_UK_scales
if hasattr(self, 'W_Q_UK_scales'):
weights['W_Q_UK_scales'] = self.W_Q_UK_scales
if hasattr(self, 'W_UV'):
weights['W_UV'] = self.W_UV
if hasattr(self, 'W_UV_scales'):
weights['W_UV_scales'] = self.W_UV_scales
if hasattr(self, 'W_UV_O'):
weights['W_UV_O'] = self.W_UV_O
if hasattr(self, 'W_UV_O_scales'):
weights['W_UV_O_scales'] = self.W_UV_O_scales
return weights
def _forward_prefill(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: VACCMLAMetadata,
) -> torch.Tensor:
assert isinstance(attn_metadata, VACCMLAMetadata)
kv_nope = self.kv_b_proj(kv_c_normed)[0]\
.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
v = v.contiguous()
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
# v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
# value=0)
# attn_output = torch.vacc.scaled_dot_product_attention(
# query=q,
# key=k,
# value=v_padded,
# attn_mask=None,
# dropout_p=0,
# is_causal=True,
# is_train=False,
# recompute=False,
# flash_attention=True,
# sm_scale=self.scale
# )
# attn_output = attn_output\
# .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
# .reshape(-1, self.num_heads * v.shape[-1])
seq_lens = attn_metadata.prefill_metadata.seq_lens
if len(seq_lens) == 1:
# Vacc supports different head dims for v and qk.
attn_output = torch.vacc.scaled_dot_product_attention(
query=q,
key=k,
value=v,
attn_mask=None,
dropout_p=0,
is_causal=True,
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=self.scale
)
attn_out = attn_output.view(-1, self.num_heads * v.shape[-1])
else:
attn_outs = []
start = 0
for seq in seq_lens:
end = start + seq
attn_out = torch.vacc.scaled_dot_product_attention(
query=q[start:end, :],
key=k[start:end, :],
value=v[start:end, :],
attn_mask=None,
dropout_p=0,
is_causal=True,
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=self.scale
)
start = end
attn_outs.append(attn_out)
attn_out = torch.cat(attn_outs, dim=0).view(-1, self.num_heads * v.shape[-1])
return self.o_proj(attn_out)[0]
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: VACCMLAMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")
decode_meta = attn_metadata.decode_metadata
assert decode_meta is not None
B = q_nope.shape[0]
q = torch.cat([q_nope, q_pe], dim=-1)
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)
# Add a head dim of 1
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
# print(f"kv_c_and_k_pe_cache: {kv_c_and_k_pe_cache.shape} ")
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
# Run MQA using paged_attention
# o = torch.vacc.paged_attention(
# query=q,
# key_cache=kv_c_and_k_pe_cache,
# value_cache=kv_c_cache,
# block_table=decode_meta.block_tables,
# seq_len=decode_meta.seq_lens_tensor,
# out=o,
# sm_scale=self.scale
# )
# Run MQA using spda
# t0 = time.time()
o = vacc_paged_attention_naive(
q,
kv_c_and_k_pe_cache,
kv_c_cache,
block_table=decode_meta.block_tables,
# seq_lens=decode_meta.seq_lens_tensor,
seq_lens=decode_meta.seq_lens,
out=o,
sm_scale=self.scale)
# print(f'{os.getpid()} paged_atten(seq: {decode_meta.seq_lens}) time: {time.time() - t0}')
return self._v_up_proj_and_o_proj(o)
def vacc_paged_attention_naive(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_table: torch.Tensor,
# seq_lens: torch.Tensor,
seq_lens: List[int],
out: Optional[torch.Tensor] = None,
sm_scale=-1
) -> torch.Tensor:
# guarantee batch=1 perf
if len(seq_lens) == 1:
k = key_cache.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[0]]
v = value_cache.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[0]]
attn_out = torch.vacc.scaled_dot_product_attention(
query=query,
key=k,
value=v,
attn_mask=None,
dropout_p=0,
is_causal=False,
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=sm_scale
)
else:
# t0 = time.time()
attn_outs = []
for i in range(len(seq_lens)):
k_slices = key_cache[block_table[i], :, :, :]
k = torch.cat([k_slices[j, :, :, :].unsqueeze(1) for j in range(len(block_table[i]))], dim=0)
k = k.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[i]]
v_slices = value_cache[block_table[i], :, :, :]
v = torch.cat([v_slices[j, :, :, :].unsqueeze(1) for j in range(len(block_table[i]))], dim=0)
v = v.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[i]]
attn_out = torch.vacc.scaled_dot_product_attention(
query=query[i:i+1,:,:],
key=k,
value=v,
attn_mask=None,
dropout_p=0,
is_causal=False,
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=sm_scale
)
attn_outs.append(attn_out)
attn_out = torch.cat(attn_outs, dim=0)
# print(f'{os.getpid()} call spda(seq: {seq_lens}) time: {time.time() - t0}')
return attn_out
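# Note on the multi-sequence path above: for each sequence i, the cache blocks
# listed in block_table[i] are gathered and concatenated, truncated to
# seq_lens[i] tokens, and attended against the single new query token with a
# non-causal SDPA; the per-sequence outputs are concatenated along dim 0.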
# MLA single op impl
def vacc_paged_attention_naive_singleop(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seq_lens,
block_table = None,
out: torch.Tensor = None,
sm_scale = -1
) -> torch.Tensor:
k = key_cache.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens]
v = value_cache.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens].squeeze(1)
pe_cache = k[..., 512:].squeeze(1)
print(f'q:{query[..., :512].shape} v:{v.shape} pe_cache:{pe_cache.shape}')
q_nope_kv_c = torch.einsum("shc,tc->sht", query[..., :512], v)
q_pe_k_pe = torch.einsum("shr,tr->sht", query[..., 512:], pe_cache)
scores = (q_nope_kv_c + q_pe_k_pe) * sm_scale
scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(query)
o = torch.einsum("sht,tc->shc", scores, v)
return o
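# Reference for the einsum labels above: s = query tokens, h = query heads,
# t = cached kv tokens, c = compressed latent channels (the first 512 cache
# channels), r = rope channels (the remainder). The 512 split mirrors the
# kv_lora_rank hard-coded above.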

@@ -0,0 +1,160 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.attention.ops.prefix_prefill import context_attention_fwd
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
@dataclass
class PagedAttentionMetadata:
"""Metadata for PagedAttention."""
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor: Optional[torch.Tensor]
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
max_decode_seq_len: int
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
class VaccPagedAttention:
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [32, 64, 80, 96, 112, 120, 128, 192, 256]
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size * num_kv_heads * head_size)
@staticmethod
def split_kv_cache(
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
# x = 16 // kv_cache.element_size()
num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, -1, num_kv_heads, head_size)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, -1, num_kv_heads, head_size)
return key_cache, value_cache
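# Worked example (illustrative sizes): with num_blocks=1024, block_size=16,
# num_kv_heads=8 and head_size=128, get_kv_cache_shape() returns
# (2, 1024, 16 * 8 * 128) and split_kv_cache() views each half as
# (1024, 16, 8, 128), i.e. (num_blocks, block_size, num_kv_heads, head_size).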
@staticmethod
def write_to_paged_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
) -> None:
# list_from_tensor = slot_mapping.tolist()
torch.vacc.reshape_and_cache_attention(key, key_cache, slot_mapping)
torch.vacc.reshape_and_cache_attention(value, value_cache, slot_mapping)
@staticmethod
def forward_decode(
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
max_seq_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
k_scale: float,
v_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> torch.Tensor:
torch.vacc.paged_attention(query, key_cache, value_cache, block_tables, seq_lens, -1, output)
return output
@staticmethod
def forward_prefix(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache_dtype: str,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
max_query_len: int,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int],
k_scale: float,
v_scale: float,
) -> torch.Tensor:
output = torch.empty_like(query)
context_attention_fwd(
query,
key,
value,
output,
kv_cache_dtype,
key_cache,
value_cache,
block_tables,
# query_start_loc is (batch_size + 1,)
query_start_loc[:-1],
seq_lens_tensor,
context_lens,
max_query_len,
k_scale,
v_scale,
alibi_slopes,
sliding_window,
)
return output
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0]
torch.vacc.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1]
torch.vacc.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
torch.vacc.copy_blocks(key_caches, value_caches, src_to_dists)

vllm_vacc/vllm/config.py Normal file

@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
import ast
import copy
import enum
import hashlib
import inspect
import json
import re
import sys
import textwrap
import warnings
from collections import Counter
from contextlib import contextmanager
from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
replace)
from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
Optional, Protocol, TypeVar, Union, get_args)
import torch
from pydantic import BaseModel, Field, PrivateAttr
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
QuantizationMethods,
get_quantization_config)
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import CpuArchEnum, current_platform
from vllm.config.model import _STR_DTYPE_TO_TORCH_DTYPE
logger = init_logger(__name__)
def ModelConfig___verify_quantization(self) -> None:
supported_quantization = QUANTIZATION_METHODS
optimized_quantization_methods = [
"fp8", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
"quark", "modelopt_fp4", "bitblas"#, "gptq_bitblas"
]
if self.quantization is not None:
self.quantization = self.quantization.lower()
# Parse quantization method from the HF model config, if available.
quant_cfg = self._parse_quant_hf_config(self.hf_config)
if quant_cfg is None and (text_config := getattr(
self.hf_config, "text_config", None)):
# Check the text config as well for multi-modal models.
quant_cfg = self._parse_quant_hf_config(text_config)
if quant_cfg is not None:
quant_method = quant_cfg.get("quant_method", "").lower()
quant_method = quant_method.replace("compressed_tensors",
"compressed-tensors")
quant_cfg["quant_method"] = quant_method
# Quantization methods which are overrides (i.e. they have a
# `override_quantization_method` method) must be checked in order
# of preference (this is particularly important for GPTQ).
overrides = [
# "marlin",
"bitblas",
"gptq_marlin_24",
"gptq_marlin",
# "gptq_bitblas",
"awq_marlin",
"ipex",
"moe_wna16",
"modelopt",
"modelopt_fp4",
"petit_nvfp4",
]
quantization_methods = [
q for q in supported_quantization if q not in overrides
]
# Any custom overrides will be in quantization_methods so we place
# them at the start of the list so custom overrides have preference
# over the built in ones.
quantization_methods = quantization_methods + overrides
# Detect which checkpoint is it
for name in quantization_methods:
method = get_quantization_config(name)
quantization_override = method.override_quantization_method(
quant_cfg, self.quantization)
if quantization_override is not None:
# Raise error if the override is not custom (custom would
# be in QUANTIZATION_METHODS but not QuantizationMethods)
# and hasn't been added to the overrides list.
if (name in get_args(QuantizationMethods)
and name not in overrides):
raise ValueError(
f"Quantization method {name} is an override but "
"is has not been added to the `overrides` list "
"above. This is necessary to ensure that the "
"overrides are checked in order of preference.")
quant_method = quantization_override
self.quantization = quantization_override
break
# Verify quantization configurations.
if self.quantization is None:
self.quantization = quant_method
elif self.quantization != quant_method:
raise ValueError(
"Quantization method specified in the model config "
f"({quant_method}) does not match the quantization "
f"method specified in the `quantization` argument "
f"({self.quantization}).")
if self.quantization is not None:
if self.quantization not in supported_quantization:
raise ValueError(
f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}.")
from vllm.platforms import current_platform
current_platform.verify_quantization(self.quantization)
if self.quantization not in optimized_quantization_methods:
logger.warning(
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models.", self.quantization)
def _get_head_dtype(config: PretrainedConfig, dtype: torch.dtype,
runner_type: str) -> torch.dtype:
head_dtype: Optional[Union[str,
torch.dtype]] = getattr(config, "head_dtype",
None)
if head_dtype == "model":
return dtype
elif isinstance(head_dtype, str):
head_dtype = head_dtype.lower()
if head_dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {head_dtype!r}")
return _STR_DTYPE_TO_TORCH_DTYPE[head_dtype]
elif isinstance(head_dtype, torch.dtype):
return head_dtype
elif head_dtype is None:
if torch.float32 not in current_platform.supported_dtypes:
return dtype
if runner_type == "pooling":
return torch.float16
return dtype
else:
raise ValueError(f"Unknown dtype: {head_dtype}")


@@ -0,0 +1,52 @@
#####################################################
## 1. used for Memory-Recycler
## .model_infos
## ['deepseek_mtp',]
##
## 2. waiting...
######################################################
import os
class ConfigManager():
def __init__(self):
self._config_name = ".model_infos"
def update_model_infos(self, model_infos: str):
from pathlib import Path
workspace_path = Path.cwd()
bootinfo_config = f'{workspace_path}/{self._config_name}'
try:
with open(bootinfo_config, 'w') as w:
w.write(model_infos)
except Exception as e:
print("[WARN] write model_infos fail, caused by ", e)
raise  # re-raise so callers see the write failure
def get_model_infos(self):
from pathlib import Path
workspace_path = Path.cwd()
bootinfo_config = f'{workspace_path}/{self._config_name}'
bootinfo_inited = os.path.exists(bootinfo_config)
runner_model_infos = "default"
if bootinfo_inited:
try:
with open(bootinfo_config) as w:
runner_model_infos = w.readline()
except Exception as e:
print("[WARN] model_infos load fail ", e)
return runner_model_infos
config_manager = None
def vllm_vacc_config_manager():
global config_manager
if config_manager is None:
config_manager = ConfigManager()
return config_manager
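# Usage sketch (illustrative): the manager persists a one-line marker in the
# current working directory.
# cm = vllm_vacc_config_manager()
# cm.update_model_infos("deepseek_mtp")   # writes ./.model_infos
# cm.get_model_infos()                    # -> "deepseek_mtp", or "default" if the file is missing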


@@ -0,0 +1,407 @@
# SPDX-License-Identifier: Apache-2.0
import math
from typing import List, Optional
from vllm.core.block.common import BlockList
from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator
from vllm.utils import Device, cdiv, chunk_list
class BlockTable:
"""A class to manage blocks for a specific sequence.
The BlockTable maps a sequence of tokens to a list of blocks, where each
block represents a contiguous memory allocation for a portion of the
sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is
responsible for allocating and freeing memory for the blocks.
Args:
block_size (int): The maximum number of tokens that can be stored in a
single block.
block_allocator (DeviceAwareBlockAllocator): The block allocator used to
manage memory for the blocks.
_blocks (Optional[List[Block]], optional): An optional list of existing
blocks to initialize the BlockTable with. If not provided, an empty
BlockTable is created.
max_block_sliding_window (Optional[int], optional): The number of
blocks to keep around for each sequence. If None, all blocks
are kept (e.g., when sliding window is not used).
It should at least fit the sliding window size of the model.
Attributes:
_block_size (int): The maximum number of tokens that can be stored in a
single block.
_allocator (DeviceAwareBlockAllocator): The block allocator used to
manage memory for the blocks.
_blocks (Optional[List[Block]]): The list of blocks managed by this
BlockTable.
_num_full_slots (int): The number of tokens currently stored in the
blocks.
"""
def __init__(
self,
block_size: int,
block_allocator: DeviceAwareBlockAllocator,
_blocks: Optional[List[Block]] = None,
max_block_sliding_window: Optional[int] = None,
):
self._block_size = block_size
self._allocator = block_allocator
if _blocks is None:
_blocks = []
self._blocks: BlockList = BlockList(_blocks)
self._max_block_sliding_window = max_block_sliding_window
self._num_full_slots = self._get_num_token_ids()
@staticmethod
def get_num_required_blocks(token_ids: List[int],
block_size: int,
num_lookahead_slots: int = 0) -> int:
"""Calculates the minimum number of blocks required to store a given
sequence of token IDs along with any look-ahead slots that may be
required (like in multi-step + chunked-prefill).
This assumes worst-case scenario, where every block requires a new
allocation (e.g. ignoring prefix caching).
Args:
token_ids (List[int]): The sequence of token IDs to be stored.
block_size (int): The maximum number of tokens that can be stored in
a single block.
num_lookahead_slots (int): look-ahead slots that the sequence may
require.
Returns:
int: The minimum number of blocks required to store the given
sequence of token IDs along with any required look-ahead slots.
"""
return cdiv(len(token_ids) + num_lookahead_slots, block_size)
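# Example: 35 token ids with block_size=16 and 2 lookahead slots require
# cdiv(35 + 2, 16) = 3 blocks.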
def allocate(self,
token_ids: List[int],
device: Device = Device.GPU,
extra_hash: Optional[int] = None,
seq_id: Optional[int] = None) -> None:
"""Allocates memory blocks for storing the given sequence of token IDs.
This method allocates the required number of blocks to store the given
sequence of token IDs.
Args:
token_ids (List[int]): The sequence of token IDs to be stored.
device (Device, optional): The device on which the blocks should be
allocated. Defaults to Device.GPU.
extra_hash (Optional[int]): The hash value of additional
factors, such as adapters, that influence the block hash
in the prefix caching block.
"""
assert not self._is_allocated
assert token_ids
blocks = self._allocate_blocks_for_token_ids(prev_block=None,
token_ids=token_ids,
device=device,
extra_hash=extra_hash,
seq_id=seq_id)
self.update(blocks)
self._num_full_slots = len(token_ids)
def update(self, blocks: List[Block]) -> None:
"""Resets the table to the newly provided blocks
(with their corresponding block ids)
"""
self._blocks.update(blocks)
def append_token_ids(self,
token_ids: List[int],
num_lookahead_slots: int = 0,
num_computed_slots: Optional[int] = None,
extra_hash: Optional[int] = None,
seq_id: Optional[int] = None) -> None:
"""Appends a sequence of token IDs to the existing blocks in the
BlockTable.
This method appends the given sequence of token IDs to the existing
blocks in the BlockTable. If there is not enough space in the existing
blocks, new blocks are allocated using the `ensure_num_empty_slots`
method to accommodate the additional tokens.
The token IDs are divided into chunks of size `block_size` (except for
the first chunk, which may be smaller), and each chunk is appended to a
separate block.
Args:
token_ids (List[int]): The sequence of token IDs to be appended.
num_computed_slots (Optional[int]): The number of KV cache slots
that are already filled (computed).
When sliding window is enabled, this is used to compute how many
blocks to drop at the front of the sequence.
Without sliding window, None can be passed.
Without chunked prefill, it should be the same as
_num_full_slots.
extra_hash (Optional[int]): The hash value of additional
factors such as adapters that influence the block, apart
from the token_ids.
"""
assert self._is_allocated, "no blocks have been allocated"
assert len(self._blocks) > 0
# Drop blocks that are no longer needed due to sliding window
if self._max_block_sliding_window is not None:
null_block = self._allocator.allocate_or_get_null_block()
assert num_computed_slots is not None
end_block_idx = (num_computed_slots //
self._block_size) - self._max_block_sliding_window
for idx in range(0, end_block_idx):
b = self._blocks[idx]
if b is not null_block:
self._allocator.free(b)
self._blocks[idx] = null_block
# Ensure there are enough empty slots for the new tokens plus
# lookahead slots
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
num_lookahead_slots,
extra_hash=extra_hash,
seq_id=seq_id)
# Update the blocks with the new tokens
first_block_idx = self._num_full_slots // self._block_size
token_blocks = self._chunk_token_blocks_for_append(token_ids)
for i, token_block in enumerate(token_blocks):
self._blocks.append_token_ids(first_block_idx + i, token_block, seq_id=seq_id)
self._num_full_slots += len(token_ids)
def ensure_num_empty_slots(self,
num_empty_slots: int,
extra_hash: Optional[int] = None,
seq_id: Optional[int] = None) -> None:
"""Ensures that the BlockTable has at least the specified number of
empty slots available.
This method checks if the BlockTable has enough empty slots (i.e.,
available space) to accommodate the requested number of tokens. If not,
it allocates additional blocks on the GPU to ensure that the required
number of empty slots is available.
Args:
num_empty_slots (int): The minimum number of empty slots required.
extra_hash (Optional[int]): The hash value of additional
factors such as adapters that influence the block, apart
from the token_ids.
"""
# Currently the block table only supports
# appending tokens to GPU blocks.
device = Device.GPU
assert self._is_allocated
if self._num_empty_slots >= num_empty_slots:
return
slots_to_allocate = num_empty_slots - self._num_empty_slots
blocks_to_allocate = cdiv(slots_to_allocate, self._block_size)
for _ in range(blocks_to_allocate):
assert len(self._blocks) > 0
self._blocks.append(
self._allocator.allocate_mutable_block(
prev_block=self._blocks[-1],
device=device,
extra_hash=extra_hash,
seq_id=seq_id))
def fork(self) -> "BlockTable":
"""Creates a new BlockTable instance with a copy of the blocks from the
current instance.
This method creates a new BlockTable instance with the same block size,
block allocator, and a copy of the blocks from the current instance. The
new BlockTable has its own independent set of blocks, but shares the
same underlying memory allocation with the original BlockTable.
Returns:
BlockTable: A new BlockTable instance with a copy of the blocks from
the current instance.
"""
assert self._is_allocated
assert len(self._blocks) > 0
forked_blocks = self._allocator.fork(self._blocks[-1])
return BlockTable(
block_size=self._block_size,
block_allocator=self._allocator,
_blocks=forked_blocks,
max_block_sliding_window=self._max_block_sliding_window,
)
def free(self, seq_id: Optional[int] = None) -> None:
"""Frees the memory occupied by the blocks in the BlockTable.
This method iterates over all the blocks in the `_blocks` list and calls
the `free` method of the `_allocator` object to release the memory
occupied by each block. After freeing all the blocks, the `_blocks` list
is reset.
"""
self.blocks.reverse()
for block in self.blocks:
self._allocator.free(block, seq_id=seq_id)
self._blocks.reset()
@property
def physical_block_ids(self) -> List[int]:
"""Returns a list of physical block indices for the blocks in the
BlockTable.
This property returns a list of integers, where each integer represents
the physical block index of a corresponding block in the `_blocks` list.
The physical block index is a unique identifier for the memory location
occupied by the block.
Returns:
List[int]: A list of physical block indices for the blocks in the
BlockTable.
"""
return self._blocks.ids()
def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
"""Get the number of "unseen" tokens in the sequence.
Unseen tokens are tokens in the sequence corresponding to this block
table, but are not yet appended to this block table.
Args:
sequence_token_ids (List[int]): The list of token ids in the
sequence.
Returns:
List[int]: The postfix of sequence_token_ids that has not yet been
appended to the block table.
"""
# Since the block table is append-only, the unseen token ids are the
# ones after the appended ones.
return sequence_token_ids[self.num_full_slots:]
def _allocate_blocks_for_token_ids(
self,
prev_block: Optional[Block],
token_ids: List[int],
device: Device,
extra_hash: Optional[int] = None,
seq_id: Optional[int] = None) -> List[Block]:
blocks: List[Block] = []
block_token_ids = []
tail_token_ids = []
for cur_token_ids in chunk_list(token_ids, self._block_size):
if len(cur_token_ids) == self._block_size:
block_token_ids.append(cur_token_ids)
else:
tail_token_ids.append(cur_token_ids)
if block_token_ids:
blocks.extend(
self._allocator.allocate_immutable_blocks(
prev_block,
block_token_ids=block_token_ids,
device=device,
extra_hash=extra_hash,
seq_id=seq_id))
prev_block = blocks[-1]
if tail_token_ids:
assert len(tail_token_ids) == 1
cur_token_ids = tail_token_ids[0]
block = self._allocator.allocate_mutable_block(
prev_block=prev_block, device=device, extra_hash=extra_hash, seq_id=seq_id)
block.append_token_ids(cur_token_ids, seq_id)
blocks.append(block)
return blocks
def _get_all_token_ids(self) -> List[int]:
# NOTE: This function is O(seq_len); use sparingly.
token_ids: List[int] = []
if not self._is_allocated:
return token_ids
for block in self.blocks:
token_ids.extend(block.token_ids)
return token_ids
def _get_num_token_ids(self) -> int:
res = 0
for block in self.blocks:
res += len(block.token_ids)
return res
@property
def _is_allocated(self) -> bool:
return len(self._blocks) > 0
@property
def blocks(self) -> List[Block]:
return self._blocks.list()
@property
def _num_empty_slots(self) -> int:
assert self._is_allocated
return len(self._blocks) * self._block_size - self._num_full_slots
@property
def num_full_slots(self) -> int:
"""Returns the total number of tokens currently stored in the
BlockTable.
Returns:
int: The total number of tokens currently stored in the BlockTable.
"""
return self._num_full_slots
def get_num_blocks_touched_by_append_slots(
self, token_ids: List[int], num_lookahead_slots: int) -> int:
"""Determine how many blocks will be "touched" by appending the token
ids.
This is required for the scheduler to determine whether a sequence can
continue generation, or if it must be preempted.
"""
# Math below is equivalent to:
# all_token_ids = token_ids + [-1] * num_lookahead_slots
# token_blocks = self._chunk_token_blocks_for_append(all_token_ids)
# return len(token_blocks)
num_token_ids = len(token_ids) + num_lookahead_slots
first_chunk_size = self._block_size - (self._num_full_slots %
self._block_size)
num_token_blocks = (1 + math.ceil(
(num_token_ids - first_chunk_size) / self._block_size))
return num_token_blocks
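# Example: with block_size=16 and 20 slots already full, first_chunk_size is
# 16 - (20 % 16) = 12; appending 30 tokens plus 4 lookahead slots (34 total)
# touches 1 + ceil((34 - 12) / 16) = 3 blocks.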
def _chunk_token_blocks_for_append(
self, token_ids: List[int]) -> List[List[int]]:
"""Split the token ids into block-sized chunks so they can be easily
appended to blocks. The first such "token block" may have less token ids
than the block size, since the last allocated block may be partially
full.
If no token ids are provided, then no chunks are returned.
"""
if not token_ids:
return []
first_chunk_size = self._block_size - (self._num_full_slots %
self._block_size)
token_blocks = [token_ids[:first_chunk_size]]
token_blocks.extend(
chunk_list(token_ids[first_chunk_size:], self._block_size))
return token_blocks
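# Example: with block_size=16 and 20 slots already full, 30 new token ids are
# chunked as [12, 16, 2] tokens, so the first chunk tops up the partially
# filled block and the remaining chunks fill fresh blocks.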


@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
from collections import deque
from dataclasses import dataclass
from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple
from vllm.core.block.interfaces import Block, BlockAllocator
BlockId = int
RefCount = int
class BlockList:
def append_token_ids(self, block_index: int, token_ids: List[int], seq_id: Optional[int]=None) -> None:
block = self._blocks[block_index]
prev_block_id = block.block_id
block.append_token_ids(token_ids, seq_id=seq_id)
# CoW or promotion may update the internal block_id
if prev_block_id != block.block_id:
self._update_block_id(block_index, block.block_id)


@@ -0,0 +1,373 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, FrozenSet, List, Optional, Tuple
from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId,
DeviceAwareBlockAllocator)
# from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.block.naive_block import NaiveBlock
from vllm_vacc.vllm.core.block.naive_block import NaiveBlockAllocator
from vllm_vacc.vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator
from vllm.platforms import current_platform
from vllm.utils import Device
from vllm.core.block.cpu_gpu_block_allocator import NullBlock
class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
"""A block allocator that can allocate blocks on both CPU and GPU memory.
This class implements the `DeviceAwareBlockAllocator` interface and provides
functionality for allocating and managing blocks of memory on both CPU and
GPU devices.
The `CpuGpuBlockAllocator` maintains separate memory pools for CPU and GPU
blocks, and allows for allocation, deallocation, forking, and swapping of
blocks across these memory pools.
"""
@staticmethod
def create(
allocator_type: str,
num_gpu_blocks: int,
num_cpu_blocks: int,
block_size: int,
) -> DeviceAwareBlockAllocator:
"""Creates a CpuGpuBlockAllocator instance with the specified
configuration.
This static method creates and returns a CpuGpuBlockAllocator instance
based on the provided parameters. It initializes the CPU and GPU block
allocators with the specified number of blocks, block size, and
allocator type.
Args:
allocator_type (str): The type of block allocator to use for CPU
and GPU blocks. Currently supported values are "naive" and
"prefix_caching".
num_gpu_blocks (int): The number of blocks to allocate for GPU
memory.
num_cpu_blocks (int): The number of blocks to allocate for CPU
memory.
block_size (int): The size of each block in number of tokens.
Returns:
DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with the
specified configuration.
Notes:
- The block IDs are assigned contiguously, with GPU block IDs coming
before CPU block IDs.
"""
# For HPU, block id 0 is used only for padding
reserved_blocks = 1 if current_platform.is_hpu() else 0
block_ids = list(
range(reserved_blocks, num_gpu_blocks + num_cpu_blocks))
num_gpu_blocks -= reserved_blocks
gpu_block_ids = block_ids[:num_gpu_blocks]
cpu_block_ids = block_ids[num_gpu_blocks:]
if allocator_type == "naive":
gpu_allocator: BlockAllocator = NaiveBlockAllocator(
create_block=NaiveBlock, # type: ignore
num_blocks=num_gpu_blocks,
block_size=block_size,
block_ids=gpu_block_ids,
)
cpu_allocator: BlockAllocator = NaiveBlockAllocator(
create_block=NaiveBlock, # type: ignore
num_blocks=num_cpu_blocks,
block_size=block_size,
block_ids=cpu_block_ids,
)
elif allocator_type == "prefix_caching":
gpu_allocator = PrefixCachingBlockAllocator(
num_blocks=num_gpu_blocks,
block_size=block_size,
block_ids=gpu_block_ids,
)
cpu_allocator = PrefixCachingBlockAllocator(
num_blocks=num_cpu_blocks,
block_size=block_size,
block_ids=cpu_block_ids,
)
else:
raise ValueError(f"Unknown allocator type {allocator_type=}")
return CpuGpuBlockAllocator(
cpu_block_allocator=cpu_allocator,
gpu_block_allocator=gpu_allocator,
)
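# Example layout (hypothetical sizes): with num_gpu_blocks=4, num_cpu_blocks=2
# and no reserved block, block_ids = [0, 1, 2, 3, 4, 5] is split into
# gpu_block_ids = [0, 1, 2, 3] and cpu_block_ids = [4, 5], i.e. GPU ids are
# assigned contiguously before CPU ids as stated in the docstring above.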
def __init__(self, cpu_block_allocator: BlockAllocator,
gpu_block_allocator: BlockAllocator):
assert not (
cpu_block_allocator.all_block_ids
& gpu_block_allocator.all_block_ids
), "cpu and gpu block allocators can't have intersection of block ids"
self._allocators = {
Device.CPU: cpu_block_allocator,
Device.GPU: gpu_block_allocator,
}
self._swap_mapping: Dict[int, int] = {}
self._null_block: Optional[Block] = None
self._block_ids_to_allocator: Dict[int, BlockAllocator] = {}
for _, allocator in self._allocators.items():
for block_id in allocator.all_block_ids:
self._block_ids_to_allocator[block_id] = allocator
def allocate_or_get_null_block(self) -> Block:
if self._null_block is None:
self._null_block = NullBlock(
self.allocate_mutable_block(None, Device.GPU))
return self._null_block
def allocate_mutable_block(self,
prev_block: Optional[Block],
device: Device,
extra_hash: Optional[int] = None,
seq_id: Optional[int] = None) -> Block:
"""Allocates a new mutable block on the specified device.
Args:
prev_block (Optional[Block]): The previous block in the sequence.
Used for prefix hashing.
device (Device): The device on which to allocate the new block.
extra_hash (Optional[int]): The hash value of additional
factors, such as adapters, that influence the block hash
in the prefix caching block.
Returns:
Block: The newly allocated mutable block.
"""
return self._allocators[device].allocate_mutable_block(
prev_block, extra_hash=extra_hash, seq_id=seq_id)
def allocate_immutable_blocks(
self,
prev_block: Optional[Block],
block_token_ids: List[List[int]],
device: Device,
extra_hash: Optional[int] = None,
seq_id: Optional[int] = None) -> List[Block]:
"""Allocates a new group of immutable blocks with the provided block
token IDs on the specified device.
Args:
prev_block (Optional[Block]): The previous block in the sequence.
Used for prefix hashing.
block_token_ids (List[int]): The list of block token IDs to be
stored in the new blocks.
device (Device): The device on which to allocate the new block.
extra_hash (Optional[int]): The hash value of additional
factors, such as adapters, that influence the block hash
in the prefix caching block.
Returns:
List[Block]: The newly allocated list of immutable blocks
containing the provided block token IDs.
"""
return self._allocators[device].allocate_immutable_blocks(
prev_block, block_token_ids, extra_hash=extra_hash, seq_id=seq_id)
def allocate_immutable_block(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Device,
extra_hash: Optional[int] = None) -> Block:
"""Allocates a new immutable block with the provided token IDs on the
specified device.
Args:
prev_block (Optional[Block]): The previous block in the sequence.
Used for prefix hashing.
token_ids (List[int]): The list of token IDs to be stored in the new
block.
device (Device): The device on which to allocate the new block.
extra_hash (Optional[int]): The hash value of additional
factors, such as adapters, that influence the block hash
in the prefix caching block.
Returns:
Block: The newly allocated immutable block containing the provided
token IDs.
"""
return self._allocators[device].allocate_immutable_block(
prev_block, token_ids, extra_hash=extra_hash)
def free(self, block: Block, seq_id: Optional[int] = None) -> None:
"""Frees the memory occupied by the given block.
Args:
block (Block): The block to be freed.
"""
# Null block should never be freed
if isinstance(block, NullBlock):
return
block_id = block.block_id
assert block_id is not None
allocator = self._block_ids_to_allocator[block_id]
allocator.free(block, seq_id=seq_id)
def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying
memory as the original sequence.
Args:
last_block (Block): The last block in the original sequence.
Returns:
List[Block]: A new list of blocks that shares the same memory as the
original sequence.
"""
# do not attempt to fork the null block
assert not isinstance(last_block, NullBlock)
block_id = last_block.block_id
assert block_id is not None
allocator = self._block_ids_to_allocator[block_id]
return allocator.fork(last_block)
def get_num_free_blocks(self, device: Device, seq_id: Optional[int] = None) -> int:
"""Returns the number of free blocks available on the specified device.
Args:
device (Device): The device for which to query the number of free
blocks. AssertionError is raised if None is passed.
Returns:
int: The number of free blocks available on the specified device.
"""
return self._allocators[device].get_num_free_blocks(seq_id=seq_id)
def get_num_total_blocks(self, device: Device, seq_id: Optional[int] = None) -> int:
return self._allocators[device].get_num_total_blocks(seq_id=seq_id)
def get_physical_block_id(self, device: Device, absolute_id: int) -> int:
"""Returns the zero-offset block id on certain device given the
absolute block id.
Args:
device (Device): The device for which to query relative block id.
absolute_id (int): The absolute block id for the block in
whole allocator.
Returns:
int: The zero-offset block id on certain device.
"""
return self._allocators[device].get_physical_block_id(absolute_id)
def swap(self, blocks: List[Block], src_device: Device,
dst_device: Device) -> Dict[int, int]:
"""Execute the swap for the given blocks from source_device
on to dest_device, save the current swap mapping and append
them to the accumulated `self._swap_mapping` for each
scheduling move.
Args:
blocks: List of blocks to be swapped.
src_device (Device): Device to swap the 'blocks' from.
dst_device (Device): Device to swap the 'blocks' to.
Returns:
Dict[int, int]: Swap mapping from source_device
on to dest_device.
"""
src_block_ids = [block.block_id for block in blocks]
self._allocators[src_device].swap_out(blocks)
self._allocators[dst_device].swap_in(blocks)
dst_block_ids = [block.block_id for block in blocks]
current_swap_mapping: Dict[int, int] = {}
for src_block_id, dst_block_id in zip(src_block_ids, dst_block_ids):
if src_block_id is not None and dst_block_id is not None:
self._swap_mapping[src_block_id] = dst_block_id
current_swap_mapping[src_block_id] = dst_block_id
return current_swap_mapping
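# Example (hypothetical ids): swapping two blocks whose GPU ids are [3, 7]
# onto CPU ids [12, 13] returns {3: 12, 7: 13}; the same pairs are also
# accumulated in self._swap_mapping until get_and_reset_swaps() drains it.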
def get_num_full_blocks_touched(self, blocks: List[Block],
device: Device) -> int:
"""Returns the number of full blocks that will be touched by
swapping in/out the given blocks on to the 'device'.
Args:
blocks: List of blocks to be swapped.
device (Device): Device to swap the 'blocks' on.
Returns:
int: the number of full blocks that will be touched by
swapping in/out the given blocks on to the 'device'.
Non full blocks are ignored when deciding the number
of blocks to touch.
"""
return self._allocators[device].get_num_full_blocks_touched(blocks)
def clear_copy_on_writes(self, seq_id: Optional[int] = None) -> List[Tuple[int, int]]:
"""Clears the copy-on-write (CoW) state and returns the mapping of
source to destination block IDs.
Returns:
List[Tuple[int, int]]: A list mapping source block IDs to
destination block IDs.
"""
# CoW only supported on GPU
device = Device.GPU
return self._allocators[device].clear_copy_on_writes()
def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
"""Mark blocks as accessed, only use for prefix caching."""
# Prefix caching only supported on GPU.
device = Device.GPU
return self._allocators[device].mark_blocks_as_accessed(block_ids, now)
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
"""Mark blocks as accessed, only use for prefix caching."""
# Prefix caching only supported on GPU.
device = Device.GPU
return self._allocators[device].mark_blocks_as_computed(block_ids)
def get_common_computed_block_ids(
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
# Prefix caching only supported on GPU.
device = Device.GPU
return self._allocators[device].get_common_computed_block_ids(
computed_seq_block_ids)
@property
def all_block_ids(self) -> FrozenSet[int]:
return frozenset(self._block_ids_to_allocator.keys())
def get_prefix_cache_hit_rate(self, device: Device) -> float:
"""Prefix cache hit rate. -1 means not supported or disabled."""
assert device in self._allocators
return self._allocators[device].get_prefix_cache_hit_rate()
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache for all devices."""
success = True
for allocator in self._allocators.values():
success = success and allocator.reset_prefix_cache()
return success
def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
"""Returns and clears the mapping of source to destination block IDs.
Will be called after every swapping operation for now, and after every
schedule once BlockManagerV2 becomes the default. Currently not useful.
Returns:
List[Tuple[int, int]]: A mapping of source to destination block IDs.
"""
mapping = self._swap_mapping.copy()
self._swap_mapping.clear()
return list(mapping.items())
def find_cached_blocks_prefix(
self,
block_hashes: List[int],
device: Device = Device.GPU,
) -> List[int]:
return self._allocators[device].find_cached_blocks_prefix(block_hashes)

View File

@@ -0,0 +1,464 @@
# SPDX-License-Identifier: Apache-2.0
from collections import deque
from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union
from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from typing import Dict
from vllm.logger import init_logger
import os
max_seq_num = int(os.getenv("MAX_SEQ_NUM", 4))
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
Refcount = int
logger = init_logger(__name__)
class NaiveBlockAllocator(BlockAllocator):
"""A simple block allocator that manages blocks of memory without prefix
caching.
Args:
create_block (Block.Factory): A factory function for creating new
blocks. This is used when a NaiveBlockAllocator is composed within
a prefix caching allocator -- the naive block allocator must
construct prefix caching blocks (but shouldn't know anything else
about them).
num_blocks (int): The total number of blocks to manage.
block_size (int): The size of each block in tokens.
block_ids (Optional[Iterable[int]], optional): An optional iterable of
block IDs. If not provided, block IDs will be assigned sequentially
from 0 to num_blocks - 1.
"""
def __init__(
self,
create_block: Block.Factory,
num_blocks: int,
block_size: int,
block_ids: Optional[Iterable[int]] = None,
block_pool: Optional[BlockPool] = None,
):
# new mapping seqid : block_group_id
self.is_partitioned = False
self.num_blocks = num_blocks
self.seq_mapping: Dict[int, List[int]] = {}
if block_ids is None:
block_ids = range(num_blocks)
self._free_block_indices_all: Deque[BlockId] = deque(block_ids)
self._all_block_indices = frozenset(block_ids)
assert len(self._all_block_indices) == num_blocks
self._refcounter = RefCounter(
all_block_indices=self._free_block_indices_all)
self._block_size = block_size
self._cow_tracker = CopyOnWriteTracker(
refcounter=self._refcounter.as_readonly())
if block_pool is None:
extra_factor = 4
# Pre-allocate "num_blocks * extra_factor" block objects.
# The "* extra_factor" is a buffer to allow more block objects
# than physical blocks
self._block_pool = BlockPool(self._block_size, create_block, self,
num_blocks * extra_factor)
else:
# In this case, the block pool is provided by the caller,
# which means that there is most likely a need to share
# a block pool between allocators
self._block_pool = block_pool
# partition blocks to block groups
self.partition_blocks()
def allocate_immutable_block(self,
prev_block: Optional[Block],
token_ids: List[int],
extra_hash: Optional[int] = None,
device: Optional[Device] = None) -> Block:
"""Allocates a new immutable block with the given token IDs, linked to
the previous block.
Args:
prev_block (Optional[Block]): The previous block in the sequence. If
None, then the block to be allocated is the first block in the
sequence.
token_ids (List[int]): The token IDs to be stored in the new block.
Returns:
Block: The newly allocated immutable block.
"""
assert device is None
block = self.allocate_mutable_block(prev_block=prev_block)
block.append_token_ids(token_ids)
return block
def allocate_immutable_blocks(
self,
prev_block: Optional[Block],
block_token_ids: List[List[int]],
extra_hash: Optional[int] = None,
device: Optional[Device] = None,
seq_id: Optional[int] = None) -> List[Block]:
assert device is None
num_blocks = len(block_token_ids)
block_ids = []
for i in range(num_blocks):
block_ids.append(self._allocate_block_id(seq_id=seq_id))
blocks = []
for i in range(num_blocks):
prev_block = self._block_pool.init_block(
prev_block=prev_block,
token_ids=block_token_ids[i],
block_size=self._block_size,
physical_block_id=block_ids[i])
blocks.append(prev_block)
return blocks
def allocate_mutable_block(self,
prev_block: Optional[Block],
extra_hash: Optional[int] = None,
device: Optional[Device] = None,
seq_id: Optional[int] = None) -> Block:
"""Allocates a new mutable block, linked to the previous block.
Args:
prev_block (Optional[Block]): The previous block in the sequence. If
None, then the block to be allocated is the first block in the
sequence.
Returns:
Block: The newly allocated mutable block.
"""
assert device is None
assert seq_id is not None
block_id = self._allocate_block_id(seq_id)
block = self._block_pool.init_block(prev_block=prev_block,
token_ids=[],
block_size=self._block_size,
physical_block_id=block_id)
return block
def _allocate_block_id(self, seq_id: Optional[int] = None) -> BlockId:
assert seq_id is not None
# always use the latest grp id to allocate
# since the previous grps are exhausted
grp_id_list = self.get_group_list(seq_id)
if len(grp_id_list) == 0 or len(self._free_block_indices[grp_id_list[-1]]) == 0:
if not self._free_block_grp_indices:
# no more block ids left in the block group pool
raise BlockAllocator.NoFreeBlocksError()
else:
# pop a new block and add to seq_mapping
grp_id = self._free_block_grp_indices.popleft()
grp_id_list.append(grp_id)
self.seq_mapping[seq_id] = grp_id_list
grp_id = grp_id_list[-1]
block_id = self._free_block_indices[grp_id].popleft()
self._refcounter.incr(block_id)
return block_id
def _free_block_id(self, block: Union[Block, BlockId], seq_id: Optional[int] = None) -> None:
assert seq_id is not None
grp_id_list = self.get_group_list(seq_id)
if isinstance(block, Block):
block_id = block.block_id
block.block_id = None
else:
block_id = block
assert block_id is not None
# block_id should always belong to the last group in grp_id_list,
# since block ids are freed in reverse allocation order
grp_id = grp_id_list[-1]
assert block_id in self._block_grp_indices[grp_id], f"grp_id: {grp_id} block_id:{block_id}"
refcount = self._refcounter.decr(block_id)
if refcount == 0:
self._free_block_indices[grp_id].appendleft(block_id)
if len(self._free_block_indices[grp_id]) == len(self._block_grp_indices[grp_id]):
# free group
self.seq_mapping[seq_id].remove(grp_id)
if len(self.seq_mapping[seq_id]) == 0:
# free seq_id
del self.seq_mapping[seq_id]
# collect back to block group pool
self._free_block_grp_indices.appendleft(grp_id)
def free(self, block: Block, keep_block_object: bool = False, seq_id: Optional[int] = None) -> None:
# Release the physical block id
self._free_block_id(block, seq_id=seq_id)
# Release the block object
if not keep_block_object:
self._block_pool.free_block(block)
def free_block_id(self, block_id: BlockId, seq_id: Optional[int] = None) -> None:
assert seq_id is not None
self._free_block_id(block_id, seq_id)
def fork(self, last_block: Block, seq_id: Optional[int] = None) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying
memory as the original sequence.
Args:
last_block (Block): The last block in the original sequence.
Returns:
List[Block]: The new sequence of blocks that shares the same memory
as the original sequence.
"""
source_blocks = get_all_blocks_recursively(last_block)
forked_blocks: List[Block] = []
prev_block = None
grp_id_list = self.get_group_list(seq_id)
for block in source_blocks:
# Increment refcount for each block.
assert block.block_id is not None
grp_id = self.get_group_id(block.block_id, grp_id_list)
assert grp_id != -1, "can't locate block group"
refcount = self._refcounter.incr(block.block_id)
assert refcount != 1, "can't fork free'd block"
forked_block = self._block_pool.init_block(
prev_block=prev_block,
token_ids=block.token_ids,
block_size=self._block_size,
physical_block_id=block.block_id)
forked_blocks.append(forked_block)
prev_block = forked_blocks[-1]
return forked_blocks
def partition_blocks(self) -> None:
# only partition once in each vllm server lifecycle
if self.is_partitioned: #and len(self.seq_mapping) > 0:
return
self.is_partitioned = True
self._blk_num_per_grp = env_blk_grp_size // self._block_size
self._all_blk_grp_num = self.num_blocks // self._blk_num_per_grp
block_groups = []
for i in range(self._all_blk_grp_num):
start = i * self._blk_num_per_grp
block_groups.append([k for k in range(start, start + self._blk_num_per_grp)])
self._free_block_grp_indices: Deque[BlockId] = deque(range(self._all_blk_grp_num))
# self._free_block_grp_indices: Deque[BlockId] = deque(range(self._all_blk_grp_num-1,-1,-1))
self._free_block_indices: List[Deque[BlockId]] = [deque(block_ids) for block_ids in block_groups]
self._block_grp_indices: List[FrozenSet] = [frozenset(block_ids) for block_ids in block_groups]
# self._all_block_indices =[frozenset(block_ids) for block_ids in block_groups]
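# Worked example (hypothetical sizes): with BLOCK_GROUP_SIZE=512, block_size=128
# and num_blocks=16, _blk_num_per_grp = 512 // 128 = 4 and _all_blk_grp_num = 4,
# so the groups are [0..3], [4..7], [8..11], [12..15]; a sequence draws whole
# groups from _free_block_grp_indices and single ids from its group's deque.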
# get the group id that contains block_id
def get_group_id(self, block_id, grp_id_list) -> int:
for i in grp_id_list:
if block_id in self._block_grp_indices[i]:
return i
assert False, f"block_id {block_id} not found in groups {grp_id_list}"
# get group id list according to seq_id
def get_group_list(self, seq_id) -> List[int]:
"""Get group id list according to the current seq_id
key: seq_id, value: [grp_id, grp_id, ...]
"""
assert seq_id is not None
grp_id_list = []
if seq_id in self.seq_mapping:
grp_id_list = self.seq_mapping[seq_id]
return grp_id_list
def get_num_free_blocks(self, seq_id: Optional[int] = None) -> int:
free_blocks = len(self._free_block_grp_indices) * self._blk_num_per_grp
if seq_id is not None:
if seq_id in self.seq_mapping:
# seq_id is already allocated
grp_id_list = self.seq_mapping[seq_id]
free_blocks += len(self._free_block_indices[grp_id_list[-1]])
else:
# new seq_id
if len(self.seq_mapping) >= max_seq_num:
return 0
# means real allocation: only whole free block groups are counted
return free_blocks
else:
# for memory usage analysis / swap
# also need to consider the free blocks inside already-assigned block groups
for _, grp_id_list in self.seq_mapping.items():
if len(grp_id_list) > 0:
free_blocks += len(self._free_block_indices[grp_id_list[-1]])
return free_blocks
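# Accounting example (hypothetical state): with _blk_num_per_grp=4, two whole
# free groups and seq_id already owning a group with 1 free id, the result is
# 2 * 4 + 1 = 9; a brand-new seq_id gets 0 once len(self.seq_mapping) reaches
# max_seq_num (the MAX_SEQ_NUM env setting), regardless of free groups.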
def get_num_total_blocks(self, seq_id: Optional[int] = None) -> int:
return len(self._all_block_indices)
def get_physical_block_id(self, absolute_id: int) -> int:
"""Returns the zero-offset block id on certain block allocator
given the absolute block id.
Args:
absolute_id (int): The absolute block id for the block
in whole allocator.
Returns:
int: The zero-offset block id on certain device.
"""
return sorted(self._all_block_indices).index(absolute_id)
@property
def refcounter(self):
return self._refcounter
@property
def all_block_ids(self) -> FrozenSet[int]:
return self._all_block_indices
def cow_block_if_not_appendable(self, block: Block, seq_id: Optional[int] = None) -> BlockId:
"""Performs a copy-on-write operation on the given block if it is not
appendable.
Args:
block (Block): The block to check for copy-on-write.
Returns:
BlockId: The block index of the new block if a copy-on-write
operation was performed, or the original block index if
no copy-on-write was necessary.
"""
src_block_id = block.block_id
assert src_block_id is not None
if self._cow_tracker.is_appendable(block):
return src_block_id
self._free_block_id(block, seq_id)
trg_block_id = self._allocate_block_id(seq_id)
self._cow_tracker.record_cow(src_block_id, trg_block_id)
return trg_block_id
def clear_copy_on_writes(self, seq_id: Optional[int] = None) -> List[Tuple[BlockId, BlockId]]:
"""Returns the copy-on-write source->destination mapping and clears it.
Returns:
List[Tuple[BlockId, BlockId]]: A list mapping source
block indices to destination block indices.
"""
return self._cow_tracker.clear_cows()
def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
"""Mark blocks as accessed, used in prefix caching.
Since the naive allocator does not implement prefix caching, we do
nothing.
"""
pass
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
"""Mark blocks as computed, used in prefix caching.
Since the naive allocator does not implement prefix caching, we do
nothing.
"""
pass
def get_common_computed_block_ids(
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
"""Determine blocks that can be skipped in prefill.
Since the naive allocator does not support prefix caching, always return
an empty list.
"""
return []
def promote_to_immutable_block(self, block: Block) -> BlockId:
raise NotImplementedError("There is no promotion for naive blocks")
def get_num_full_blocks_touched(self, blocks: List[Block]) -> int:
"""Returns the number of full blocks that will be touched by
swapping in/out.
Args:
blocks: List of blocks to be swapped.
Returns:
int: the number of full blocks that will be touched by
swapping in/out the given blocks. Non full blocks are ignored
when deciding the number of blocks to touch.
"""
# NOTE: for naive block, we use set to eliminate common blocks among
# seqs, also we compare the empty slots in the mutable blocks with
# lookahead slots to get the number of unique new block that are
# needed.
old_block_set = set()
for block in blocks:
if block.is_full:
old_block_set.add(block)
return len(old_block_set)
def swap_out(self, blocks: List[Block], seq_id: Optional[int] = None) -> None:
for block in blocks:
self._free_block_id(block, seq_id)
def swap_in(self, blocks: List[Block]) -> None:
for block in blocks:
# Here we allocate either immutable or mutable block and then
# extract its block_id. Note that the block object is released
# and the block_id is assigned to "block" to allow reusing the
# existing "block" object
if block.is_full:
tmp_block = self.allocate_immutable_block(
prev_block=block.prev_block, token_ids=block.token_ids)
else:
tmp_block = self.allocate_mutable_block(
prev_block=block.prev_block)
tmp_block.append_token_ids(block.token_ids)
block_id = tmp_block.block_id
tmp_block.block_id = None
self._block_pool.free_block(tmp_block)
block.block_id = block_id # Assign block_id
def get_prefix_cache_hit_rate(self) -> float:
return -1
def reset_prefix_cache(self) -> bool:
"""No prefix cache for naive block allocator."""
return True
def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
# Not applicable for naive block allocator.
return []
class NaiveBlock(Block):
def append_token_ids(self, token_ids: List[int], seq_id: Optional[int] = None) -> None:
"""Appends the given token IDs to the block and performs a
copy-on-write if necessary.
Args:
token_ids (List[int]): The token IDs to be appended
to the block.
"""
assert seq_id is not None
self._append_token_ids_no_cow(token_ids)
if self._block_id is not None:
self._block_id = (self._allocator.cow_block_if_not_appendable(
self._cow_target, seq_id))

View File

@@ -0,0 +1,942 @@
# SPDX-License-Identifier: Apache-2.0
"""Token blocks."""
import sys
from bisect import bisect_left
from os.path import commonprefix
from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set,
Tuple)
from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
get_all_blocks_recursively)
from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device,
DeviceAwareBlockAllocator)
from vllm.core.block.naive_block import (BlockPool, NaiveBlock)
from vllm_vacc.vllm.core.block.naive_block import NaiveBlockAllocator
from vllm.core.block.prefix_caching_block import BlockTracker, assert_prefix_caching_block_or_none
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
from vllm.logger import init_logger
from vllm.sequence import Sequence
logger = init_logger(__name__)
PrefixHash = int
# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME
# so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME,
# then we know this block hasn't been accessed yet.
_DEFAULT_LAST_ACCESSED_TIME = -1
logger = init_logger(__name__)
class PrefixCachingBlockAllocator(BlockAllocator):
"""A block allocator that implements prefix caching.
The PrefixCachingBlockAllocator maintains a cache of blocks based on their
content hash. It reuses blocks with the same content hash to avoid redundant
memory allocation. The allocator also supports copy-on-write operations.
Args:
num_blocks (int): The total number of blocks to manage.
block_size (int): The size of each block in tokens.
block_ids(Optional[Iterable[int]], optional): An optional iterable of
block IDs. If not provided, block IDs will be assigned sequentially
from 0 to num_blocks - 1.
"""
# Note that we use 'None' as a string here instead of None because
# as of Python 3.12, hash(None) returns a constant predictable value.
# This could possibly make it easier to find and exploit hash
# collisions. 'None' as a string will be hashed differently per process,
# but consistently within the same process. This is the same as the
# behavior of None prior to Python 3.12.
_none_hash: int = hash('None')
# Implements Block.Factory.
def __init__(
self,
num_blocks: int,
block_size: int,
block_ids: Optional[Iterable[int]] = None,
eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
):
if block_ids is None:
block_ids = range(num_blocks)
self._block_size = block_size
# A mapping of prefix hash to block index. All blocks which have a
# prefix hash will be in this dict, even if they have refcount 0.
self._cached_blocks: Dict[PrefixHash, BlockId] = {}
# A list of immutable block IDs that have been touched by scheduler
# and should be marked as computed after an entire batch of sequences
# are scheduled.
self._touched_blocks: Set[BlockId] = set()
# Used to track status of each physical block id
self._block_tracker: Dict[BlockId, BlockTracker] = {}
for block_id in block_ids:
self._block_tracker[block_id] = BlockTracker()
# Pre-allocate "num_blocks * extra_factor" block objects.
# The "* extra_factor" is a buffer to allow more block objects
# than physical blocks
extra_factor = 4
self._block_pool = BlockPool(self._block_size, self._create_block,
self, num_blocks * extra_factor)
# An allocator for blocks that do not have prefix hashes.
self._hashless_allocator = NaiveBlockAllocator(
create_block=self._create_block, # type: ignore
num_blocks=num_blocks,
block_size=block_size,
block_ids=block_ids,
block_pool=self._block_pool, # Share block pool here
)
# Evictor used to maintain how we want to handle those computed blocks
# if we find memory pressure is high.
self.eviction_policy = eviction_policy
self.evictor: Evictor = make_evictor(self.eviction_policy)
# We share the refcounter between allocators. This allows us to promote
# blocks originally allocated in the hashless allocator to immutable
# blocks.
self._refcounter = self._hashless_allocator.refcounter
self._cow_tracker = CopyOnWriteTracker(
refcounter=self._refcounter.as_readonly())
self.metric_data = CacheMetricData()
def _create_block(
self,
prev_block: Optional[Block],
token_ids: List[int],
block_size: int,
allocator: BlockAllocator,
block_id: Optional[int] = None,
computed: bool = False,
extra_hash: Optional[int] = None,
) -> Block:
# Bind block to self.
allocator = self
return PrefixCachingBlock(
prev_block=prev_block,
token_ids=token_ids,
block_size=block_size,
block_id=block_id,
allocator=allocator,
computed=computed,
extra_hash=extra_hash,
)
def allocate_immutable_block(self,
prev_block: Optional[Block],
token_ids: List[int],
extra_hash: Optional[int] = None,
device: Optional[Device] = None,
seq_id: Optional[int] = None) -> Block:
"""Allocates an immutable block with the given token IDs, reusing cached
blocks if possible.
Args:
prev_block (Optional[Block]): The previous block in the sequence.
token_ids (List[int]): The token IDs to be stored in the block.
Returns:
Block: The allocated immutable block.
"""
assert device is None
assert_prefix_caching_block_or_none(prev_block)
# First, try to create a block that points to cached data
block = self._block_pool.init_block(prev_block=prev_block,
token_ids=token_ids,
block_size=self._block_size,
physical_block_id=None,
extra_hash=extra_hash)
assert block.content_hash is not None
cached_block_id = self._cached_blocks.get(block.content_hash, None)
if cached_block_id is not None:
self.metric_data.query(hit=True)
block.block_id = cached_block_id
self._incr_refcount_cached_block(block)
return block
self.metric_data.query(hit=False)
self._block_pool.free_block(block)
# No cached block => Allocate a new block
block = self.allocate_mutable_block(prev_block, extra_hash=extra_hash, seq_id=seq_id)
logger.warning(f"Teng seq_id: {seq_id} block: {block.block_id} hash: {block.content_hash}")
block.append_token_ids(token_ids, seq_id=seq_id)
return block
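# Sketch of the hit/miss flow above (hypothetical hash h): if h is already in
# _cached_blocks, the temporary block is re-pointed at the cached id and its
# refcount incremented (hit); otherwise the temporary block goes back to the
# pool, a fresh mutable block is filled, and it is promoted to an immutable
# cached block once full via promote_to_immutable_block() (miss).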
def allocate_immutable_blocks(
self,
prev_block: Optional[Block],
block_token_ids: List[List[int]],
extra_hash: Optional[int] = None,
device: Optional[Device] = None,
seq_id: Optional[int] = None) -> List[Block]:
blocks = []
for token_ids in block_token_ids:
prev_block = self.allocate_immutable_block(prev_block=prev_block,
token_ids=token_ids,
device=device,
extra_hash=extra_hash,
seq_id=seq_id)
blocks.append(prev_block)
return blocks
def allocate_mutable_block(self,
prev_block: Optional[Block],
extra_hash: Optional[int] = None,
device: Optional[Device] = None,
seq_id: Optional[int] = None) -> Block:
"""Allocates a mutable block. If there are no free blocks, this will
evict unused cached blocks.
Args:
prev_block (Block): The previous block in the sequence.
None is not allowed, unlike in the super class.
Returns:
Block: The allocated mutable block.
"""
assert device is None
assert seq_id is not None
assert_prefix_caching_block_or_none(prev_block)
block_id = self._allocate_block_id(seq_id)
block = self._block_pool.init_block(prev_block=prev_block,
token_ids=[],
block_size=self._block_size,
physical_block_id=block_id,
extra_hash=extra_hash)
assert not block.computed
assert block.content_hash is None
return block
def _incr_refcount_cached_block(self, block: Block) -> None:
# Set this block to be "computed" since it is pointing to a
# cached block id (which was already computed)
block.computed = True
block_id = block.block_id
assert block_id is not None
refcount = self._refcounter.incr(block_id)
if refcount == 1:
# In case a cached block was evicted, restore its tracking
if block_id in self.evictor:
self.evictor.remove(block_id)
self._track_block_id(block_id, computed=True)
def _decr_refcount_cached_block(self, block: Block) -> None:
# Ensure this is immutable/cached block
assert block.content_hash is not None
block_id = block.block_id
assert block_id is not None
refcount = self._refcounter.decr(block_id)
if refcount > 0:
block.block_id = None
return
else:
assert refcount == 0
# No longer used
assert block.content_hash in self._cached_blocks
# Add the cached block to the evictor
# (This keeps the cached block around so it can be reused)
self.evictor.add(block_id, block.content_hash, block.num_tokens_total,
self._block_tracker[block_id].last_accessed)
# Stop tracking the block
self._untrack_block_id(block_id)
block.block_id = None
def _decr_refcount_hashless_block(self, block: Block, seq_id: Optional[int] = None) -> None:
block_id = block.block_id
assert block_id is not None
# We may have a fork case where block is shared,
# in which case, we cannot remove it from tracking
refcount = self._refcounter.get(block_id)
if refcount == 1:
self._untrack_block_id(block_id)
# Decrement refcount of the block_id, but do not free the block object
# itself (will be handled by the caller)
self._hashless_allocator.free(block, keep_block_object=True, seq_id=seq_id)
def _allocate_block_id(self, seq_id: Optional[int] = None) -> BlockId:
"""First tries to allocate a block id from the hashless allocator,
and if there are no blocks, then tries to evict an unused cached block.
"""
assert seq_id is not None
hashless_block_id = self._maybe_allocate_hashless_block_id(seq_id=seq_id)
if hashless_block_id is not None:
return hashless_block_id
evicted_block_id = self._maybe_allocate_evicted_block_id()
if evicted_block_id is not None:
return evicted_block_id
# No block available in hashless allocator, nor in unused cache blocks.
raise BlockAllocator.NoFreeBlocksError()
def _maybe_allocate_hashless_block_id(self, seq_id: Optional[int] = None) -> Optional[BlockId]:
try:
# Allocate mutable block and extract its block_id
block = self._hashless_allocator.allocate_mutable_block(
prev_block=None, seq_id=seq_id)
block_id = block.block_id
self._block_pool.free_block(block)
self._track_block_id(block_id, computed=False)
return block_id
except BlockAllocator.NoFreeBlocksError:
return None
def _maybe_allocate_evicted_block_id(self) -> Optional[BlockId]:
if self.evictor.num_blocks == 0:
return None
# Here we get an evicted block, which is only added
# into evictor if its ref counter is 0
# and since its content would be changed, we need
# to remove it from _cached_blocks's tracking list
block_id, content_hash_to_evict = self.evictor.evict()
# Sanity checks
assert content_hash_to_evict in self._cached_blocks
_block_id = self._cached_blocks[content_hash_to_evict]
assert self._refcounter.get(_block_id) == 0
assert _block_id == block_id
self._cached_blocks.pop(content_hash_to_evict)
self._refcounter.incr(block_id)
self._track_block_id(block_id, computed=False)
return block_id
def _free_block_id(self, block: Block, seq_id: Optional[int] = None) -> None:
"""Decrements the refcount of the block. The block may be in two
possible states: (1) immutable/cached or (2) mutable/hashless.
In the first case, the refcount is decremented directly and the block
may be possibly added to the evictor. In other case, hashless
allocator free(..) with keep_block_object=True is called to only free
the block id (since the block object may be reused by the caller)
"""
block_id = block.block_id
assert block_id is not None, "Freeing unallocated block is undefined"
if block.content_hash is not None:
# Immutable: This type of block is always cached, and we want to
# keep it in the evictor for future reuse
self._decr_refcount_cached_block(block)
else:
# Mutable: This type of block is not cached, so we release it
# directly to the hashless allocator
self._decr_refcount_hashless_block(block, seq_id=seq_id)
assert block.block_id is None
def free(self, block: Block, keep_block_object: bool = False, seq_id: Optional[int] = None) -> None:
"""Release the block (look at free_block_id(..) docs)
"""
# Release the physical block index
self._free_block_id(block, seq_id=seq_id)
# Release the block object to the pool
if not keep_block_object:
self._block_pool.free_block(block)
def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying
memory as the original sequence.
Args:
last_block (Block): The last block in the original sequence.
Returns:
List[Block]: The new sequence of blocks that shares the same memory
as the original sequence.
"""
source_blocks = get_all_blocks_recursively(last_block)
forked_blocks: List[Block] = []
prev_block = None
for block in source_blocks:
block_id = block.block_id
assert block_id is not None
refcount = self._refcounter.incr(block_id)
assert refcount != 1, "can't fork free'd block_id = {}".format(
block_id)
forked_block = self._block_pool.init_block(
prev_block=prev_block,
token_ids=block.token_ids,
block_size=self._block_size,
physical_block_id=block_id,
extra_hash=block.extra_hash)
forked_blocks.append(forked_block)
prev_block = forked_blocks[-1]
return forked_blocks
def get_num_free_blocks(self, seq_id: Optional[int] = None, device: Optional[Device] = None) -> int:
assert device is None
# The number of free blocks is the number of hashless free blocks
# plus the number of blocks evictor could free from its list.
return (self._hashless_allocator.get_num_free_blocks(seq_id=seq_id) +
self.evictor.num_blocks)
def get_num_total_blocks(self, seq_id: Optional[int] = None) -> int:
return self._hashless_allocator.get_num_total_blocks(seq_id=seq_id)
def get_physical_block_id(self, absolute_id: int) -> int:
"""Returns the zero-offset block id on certain block allocator
given the absolute block id.
Args:
absolute_id (int): The absolute block id for the block
in whole allocator.
Returns:
int: The zero-offset block id on certain device.
"""
return sorted(self.all_block_ids).index(absolute_id)
@property
def all_block_ids(self) -> FrozenSet[int]:
return self._hashless_allocator.all_block_ids
def get_prefix_cache_hit_rate(self) -> float:
return self.metric_data.get_hit_rate()
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
flows to invalid prefix caching after the weights are updated,
or used for resetting prefix caching status for benchmarking.
Returns:
bool: True if the prefix cache is successfully reset,
False otherwise.
"""
num_used_blocks = (self.get_num_total_blocks() -
self.get_num_free_blocks())
if num_used_blocks > 0:
logger.warning(
"Failed to reset prefix cache because some "
"blocks (%d) are not freed yet", num_used_blocks)
return False
# Free all blocks in the evictor.
while (block_id :=
self._maybe_allocate_evicted_block_id()) is not None:
# TODO: Teng
self._hashless_allocator.free_block_id(block_id)
# Should not have any cached blocks because all blocks are evicted.
assert not self._cached_blocks
# Reset the evictor.
self.evictor = make_evictor(self.eviction_policy)
# Reset the block tracker.
for block_id in self._block_tracker:
self._block_tracker[block_id] = BlockTracker()
# Reset the metrics.
self.metric_data = CacheMetricData()
logger.info("Successfully reset prefix cache")
return True
def is_block_cached(self, block: Block) -> bool:
assert block.content_hash is not None
return block.content_hash in self._cached_blocks
def promote_to_immutable_block(self, block: Block, seq_id: Optional[int] = None) -> BlockId:
"""Once a mutable block is full, it can be promoted to an immutable
block. This means that its content can be referenced by future blocks
having the same prefix.
Note that if we already have a cached block with the same content, we
will replace the newly-promoted block's mapping with the existing cached
block id.
Args:
block: The mutable block to be promoted.
Returns:
BlockId: Either the original block index, or the block index of
the previously cached block matching the same content.
"""
# Ensure block can be promoted
assert block.content_hash is not None
assert block.block_id is not None
assert self._refcounter.get(block.block_id) > 0
if block.content_hash not in self._cached_blocks:
# No cached content hash => Set this block as cached.
# Note that this block cannot be marked as computed yet
# because other sequences in the same batch cannot reuse
# this block.
self._cached_blocks[block.content_hash] = block.block_id
# Mark this block as touched so that it can be marked as
# computed after the entire batch of sequences are scheduled.
self._touched_blocks.add(block.block_id)
return block.block_id
# Reuse the cached content hash
self._decr_refcount_hashless_block(block, seq_id=seq_id)
block.block_id = self._cached_blocks[block.content_hash]
# Increment refcount of the cached block and (possibly) restore
# it from the evictor.
# Note that in this case, the block is marked as computed
self._incr_refcount_cached_block(block)
return block.block_id
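# Promotion example (hypothetical ids): a full mutable block with id 7 and
# content hash h either registers {h: 7} and is marked as "touched", or, if
# {h: 3} already exists, releases id 7 back to the hashless allocator and
# adopts the cached id 3, so both sequences share one physical block.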
def cow_block_if_not_appendable(self, block: Block, seq_id: Optional[int] = None) -> BlockId:
"""Performs a copy-on-write operation on the given block if it is not
appendable.
Args:
block (Block): The block to check for copy-on-write.
Returns:
BlockId: The block index of the new block if a copy-on-write
operation was performed, or the original block index if
no copy-on-write was necessary.
"""
src_block_id = block.block_id
assert src_block_id is not None
if self._cow_tracker.is_appendable(block):
return src_block_id
self._free_block_id(block, seq_id=seq_id)
trg_block_id = self._allocate_block_id(seq_id)
self._cow_tracker.record_cow(src_block_id, trg_block_id)
return trg_block_id
def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]:
"""Returns the copy-on-write source->destination mapping and clears it.
Returns:
List[Tuple[BlockId, BlockId]]: A list mapping source
block indices to destination block indices.
"""
return self._cow_tracker.clear_cows()
def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
"""Mark blocks as accessed, used in prefix caching.
If the block is added into evictor, we need to update corresponding
info in evictor's metadata.
"""
for block_id in block_ids:
if self._block_tracker[block_id].active:
self._block_tracker[block_id].last_accessed = now
elif block_id in self.evictor:
self.evictor.update(block_id, now)
else:
raise ValueError(
"Mark block as accessed which is not belonged to GPU")
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
# Mark all touched blocks as computed.
for block_id in self._touched_blocks:
self._block_tracker[block_id].computed = True
self._touched_blocks.clear()
def _track_block_id(self, block_id: Optional[BlockId],
computed: bool) -> None:
assert block_id is not None
self._block_tracker[block_id].enable()
self._block_tracker[block_id].computed = computed
def _untrack_block_id(self, block_id: Optional[BlockId]) -> None:
assert block_id is not None
self._block_tracker[block_id].disable()
def block_is_computed(self, block_id: int) -> bool:
if self._block_tracker[block_id].active:
return self._block_tracker[block_id].computed
else:
return block_id in self.evictor
def get_common_computed_block_ids(
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
"""Return the block ids that are common for a given sequence group.
Only those blocks that are immutable and already marked as computed
are taken into consideration.
"""
# NOTE We exclude the last block to avoid the case where the entire
# prompt is cached. This would cause erroneous behavior in model
# runner.
# It returns a list of int although type annotation says list of string.
if len(computed_seq_block_ids) == 1:
return computed_seq_block_ids[0]
return commonprefix([
ids for ids in computed_seq_block_ids # type: ignore
if ids
])
def get_num_full_blocks_touched(self, blocks: List[Block]) -> int:
"""Returns the number of full blocks that will be touched by
swapping in/out.
Args:
blocks: List of blocks to be swapped.
Returns:
int: the number of full blocks that will be touched by
swapping in/out the given blocks. Non full blocks are ignored
when deciding the number of blocks to touch.
"""
num_touched_blocks: int = 0
for block in blocks:
# If the block has a match in the cache and the cached
# block is not referenced, then we still count it as a
# touched block
if block.is_full and (not self.is_block_cached(block) or \
(block.content_hash is not None and \
self._cached_blocks[block.content_hash] in \
self.evictor)):
num_touched_blocks += 1
return num_touched_blocks
def swap_out(self, blocks: List[Block]) -> None:
"""Execute the swap out actions. Basically just free the
given blocks.
Args:
blocks: List of blocks to be swapped out.
"""
for block in blocks:
self._free_block_id(block)
def swap_in(self, blocks: List[Block]) -> None:
"""Execute the swap in actions. Change the block id from
old allocator to current allocator for each block to finish
the block table update.
Args:
blocks: List of blocks to be swapped in.
"""
for block in blocks:
# Here we allocate either immutable or mutable block and then
# extract its block_id. Note that the block object is released
# and the block_id is assigned to "block" to allow reusing the
# existing "block" object
if block.is_full:
tmp_block = self.allocate_immutable_block(
prev_block=block.prev_block,
token_ids=block.token_ids,
extra_hash=block.extra_hash)
else:
tmp_block = self.allocate_mutable_block(
prev_block=block.prev_block, extra_hash=block.extra_hash)
tmp_block.append_token_ids(block.token_ids)
block_id = tmp_block.block_id
self._block_pool.free_block(tmp_block)
block.block_id = block_id # Assign block_id
def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
"""
Given a list of block hashes, return the prefix of the block hashes that
are all cached.
Since a block's block hash includes the hashes of all previous blocks,
and we only allocate/deallocate blocks in the entire sequence, so if a
block is cached, then all previous blocks are also cached. With this
property, we can use binary search to find the prefix of cached blocks.
Args:
block_hashes (List[int]): The list of block hashes.
Returns:
List[int]: The prefix of the `block_hashes` that are cached.
"""
def _block_is_cached(block_hash: PrefixHash) -> bool:
if block_hash not in self._cached_blocks:
return False
cached_block_id = self._cached_blocks[block_hash]
# We only consider the blocks that are marked as computed.
return self.block_is_computed(cached_block_id)
def _bisect_left(a, x, key: Callable[[PrefixHash], bool]) -> int:
# Python < 3.10 doesn't have the key argument
if sys.version_info < (3, 10):
a = [key(e) for e in a]
return bisect_left(a, x)
else:
return bisect_left(a, x, key=key)
# Look for the first block that's not cached, and returns the prefix
# i.e. blocks that are cached.
idx = _bisect_left(block_hashes,
True,
key=lambda x: not _block_is_cached(x))
return block_hashes[:idx]
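# Example (hypothetical hashes): if block_hashes = [h0, h1, h2, h3] and only
# h0 and h1 map to computed cached blocks, the key sequence is
# [False, False, True, True], bisect_left finds index 2, and the function
# returns [h0, h1] -- the longest fully-cached prefix.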
class PrefixCachingBlock(Block):
"""A block implementation that supports prefix caching.
The PrefixCachingBlock class represents a block of token IDs with prefix
caching capabilities. It wraps a NaiveBlock internally and provides
additional functionality for content hashing and promoting immutable blocks
with the prefix caching allocator.
Args:
prev_block (Optional[PrefixCachingBlock]): The previous block in the
sequence.
token_ids (List[int]): The initial token IDs to be stored in the block.
block_size (int): The maximum number of token IDs that can be stored in
the block.
allocator (BlockAllocator): The prefix
caching block allocator associated with this block.
block_id (Optional[int], optional): The physical block index
of this block. Defaults to None.
extra_hash (Optional[int]): The hash value of additional factors
such as adapters that influence the block, apart from the token_ids.
"""
# Note that we use 'None' as a string here instead of None because
# as of Python 3.12, hash(None) returns a constant predictable value.
# This could possibly make it easier to find and exploit hash
# collisions. 'None' as a string will be hashed differently per process,
# but consistently within the same process. This is the same as the
# behavior of None prior to Python 3.12.
_none_hash: int = hash('None')
def __init__(
self,
prev_block: Optional[Block],
token_ids: List[int],
block_size: int,
allocator: BlockAllocator,
block_id: Optional[int] = None,
computed: bool = False,
extra_hash: Optional[int] = None,
):
assert isinstance(allocator, PrefixCachingBlockAllocator), (
"Currently this class is only tested with "
"PrefixCachingBlockAllocator. Got instead allocator = {}".format(
allocator))
assert_prefix_caching_block_or_none(prev_block)
self._prev_block = prev_block
self._cached_content_hash: Optional[int] = None
self._cached_num_tokens_total: int = 0
self._allocator = allocator
self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
self._computed = computed
self._extra_hash = extra_hash
# On the first time, we create the block object, and next we only
# reinitialize it
if hasattr(self, "_block"):
self._block.__init__( # type: ignore[has-type]
prev_block=prev_block,
token_ids=token_ids,
block_size=block_size,
block_id=block_id,
allocator=self._allocator)
else:
self._block = NaiveBlock(prev_block=prev_block,
token_ids=token_ids,
block_size=block_size,
block_id=block_id,
allocator=self._allocator)
self._update_num_tokens_total()
def _update_num_tokens_total(self):
"""Incrementally computes the number of tokens that there is
till the current block (included)
"""
res = 0
# Add all previous blocks
if self._prev_block is not None:
res += self._prev_block.num_tokens_total
# Add current block
res += len(self.token_ids)
self._cached_num_tokens_total = res
@property
def computed(self) -> bool:
return self._computed
@computed.setter
def computed(self, value) -> None:
self._computed = value
@property
def last_accessed(self) -> float:
return self._last_accessed
@last_accessed.setter
def last_accessed(self, last_accessed_ts: float):
self._last_accessed = last_accessed_ts
def append_token_ids(self, token_ids: List[int], seq_id: Optional[int] = None) -> None:
"""Appends the given token IDs to the block and registers the block as
immutable if the block becomes full.
Args:
token_ids (List[int]): The token IDs to be appended to the block.
"""
# Ensure this is mutable block (not promoted)
assert self.content_hash is None
assert not self.computed
assert seq_id is not None
if len(token_ids) == 0:
return
# Ensure there are input tokens
assert token_ids, "Got token_ids = {}".format(token_ids)
# Naive block handles CoW.
self._block.append_token_ids(token_ids, seq_id=seq_id)
self._update_num_tokens_total()
# If the content hash is present, then the block can be made immutable.
# Register ourselves with the allocator, potentially replacing the
# physical block index.
if self.content_hash is not None:
self.block_id = self._allocator.promote_to_immutable_block(self, seq_id=seq_id)
@property
def block_id(self) -> Optional[int]:
return self._block.block_id
@block_id.setter
def block_id(self, value) -> None:
self._block.block_id = value
@property
def is_full(self) -> bool:
return self._block.is_full
@property
def num_empty_slots(self) -> int:
return self._block.num_empty_slots
@property
def num_tokens_total(self) -> int:
return self._cached_num_tokens_total
@property
def block_size(self) -> int:
return self._block.block_size
@property
def token_ids(self) -> List[int]:
return self._block.token_ids
@property
def prev_block(self) -> Optional[Block]:
return self._prev_block
@property
def extra_hash(self) -> Optional[int]:
return self._extra_hash
@property
def content_hash(self) -> Optional[int]:
"""Return the content-based hash of the current block, or None if it is
not yet defined.
For the content-based hash to be defined, the current block must be
full.
"""
# If the hash is already computed, return it.
if self._cached_content_hash is not None:
return self._cached_content_hash
# We cannot compute a hash for the current block because it is not full.
if not self.is_full:
return None
is_first_block = self._prev_block is None
prev_block_hash = (
self._none_hash if is_first_block else
self._prev_block.content_hash # type: ignore
)
# Previous block exists but does not yet have a hash.
# Return no hash in this case.
if prev_block_hash == self._none_hash and not is_first_block:
return None
self._cached_content_hash = PrefixCachingBlock.hash_block_tokens(
is_first_block,
prev_block_hash,
cur_block_token_ids=self.token_ids,
extra_hash=self._extra_hash)
return self._cached_content_hash
@classmethod
def hash_block_tokens(cls,
is_first_block: bool,
prev_block_hash: Optional[int],
cur_block_token_ids: List[int],
extra_hash: Optional[int] = None) -> int:
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
prefix caching.
Parameters:
- is_first_block (bool): A flag indicating if the block is the first in
the sequence.
- prev_block_hash (Optional[int]): The hash of the previous block. None
if this is the first block.
- cur_block_token_ids (List[int]): A list of token ids in the current
block. The current block is assumed to be full.
- extra_hash (Optional[int]): The hash value of additional factors
such as adapters that influence the block, apart from the token_ids.
Returns:
- int: The computed hash value for the block.
"""
if is_first_block and prev_block_hash is None:
prev_block_hash = cls._none_hash
return hash((is_first_block, prev_block_hash, *cur_block_token_ids,
extra_hash))
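# Chained-hash example (hypothetical tokens): the first full block hashes the
# tuple (True, cls._none_hash, 1, 2, 3, 4, None); the next block then hashes
# (False, <hash of the first block>, *its token ids, extra_hash), so each
# block's content hash transitively covers all preceding tokens.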

View File

@@ -0,0 +1,575 @@
# SPDX-License-Identifier: Apache-2.0
"""A block manager that manages token blocks."""
from typing import Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Tuple
from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.block.interfaces import Block
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
LastAccessBlocksTracker)
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
import os
SeqId = int
EncoderSeqId = str
import logging
logger = logging.getLogger(__name__)
from vllm_vacc.vllm.model_executor.models.vars import LLM_MAX_PREFILL_SEQ_LEN
max_seq_num = int(os.getenv("MAX_SEQ_NUM", 4))
if max_seq_num not in [1, 2, 4]:
max_seq_num = 4
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
class SelfAttnBlockSpaceManager(BlockSpaceManager):
"""BlockSpaceManager which manages the allocation of KV cache.
It owns responsibility for allocation, swapping, allocating memory for
autoregressively-generated tokens, and other advanced features such as
prefix caching, forking/copy-on-write, and sliding-window memory allocation.
This class implements the design described in
https://github.com/vllm-project/vllm/pull/3492.
Lookahead slots
The block manager has the notion of a "lookahead slot". These are slots
in the KV cache that are allocated for a sequence. Unlike the other
allocated slots, the content of these slots is undefined -- the worker
may use the memory allocations in any way.
In practice, a worker could use these lookahead slots to run multiple
forward passes for a single scheduler invocation. Each successive
forward pass would write KV activations to the corresponding lookahead
slot. This allows low inter-token latency use-cases, where the overhead
of continuous batching scheduling is amortized over >1 generated tokens.
Speculative decoding uses lookahead slots to store KV activations of
proposal tokens.
See https://github.com/vllm-project/vllm/pull/3250 for more information
on lookahead scheduling.
Args:
block_size (int): The size of each memory block.
num_gpu_blocks (int): The number of memory blocks allocated on GPU.
num_cpu_blocks (int): The number of memory blocks allocated on CPU.
watermark (float, optional): The threshold used for memory swapping.
Defaults to 0.01.
sliding_window (Optional[int], optional): The size of the sliding
window. Defaults to None.
enable_caching (bool, optional): Flag indicating whether caching is
enabled. Defaults to False.
"""
def __init__(
self,
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
watermark: float = 0.01,
sliding_window: Optional[int] = None,
enable_caching: bool = False,
) -> None:
self.block_size = block_size
self.num_total_gpu_blocks = num_gpu_blocks
self.num_total_cpu_blocks = num_cpu_blocks
self.per_gpu_blocks = num_gpu_blocks // max_seq_num
self.sliding_window = sliding_window
# max_block_sliding_window is the max number of blocks that need to be
# allocated
self.max_block_sliding_window = None
if sliding_window is not None:
# +1 here because // rounds down
num_blocks = sliding_window // block_size + 1
# +1 here because the last block may not be full,
# and so the sequence stretches one more block at the beginning
# For example, if sliding_window is 3 and block_size is 4,
# we may need 2 blocks when the second block only holds 1 token.
self.max_block_sliding_window = num_blocks + 1
self.watermark = watermark
assert watermark >= 0.0
self.enable_caching = enable_caching
# self.watermark_blocks = 1 # for test
self.watermark_blocks = int(watermark * self.per_gpu_blocks)
self.block_allocator = CpuGpuBlockAllocator.create(
allocator_type="prefix_caching" if enable_caching else "naive",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
block_size=block_size,
)
self.block_tables: Dict[SeqId, BlockTable] = {}
self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {}
self._computed_blocks_tracker = ComputedBlocksTracker(
self.block_allocator, self.block_size, self.enable_caching)
self._last_access_blocks_tracker = LastAccessBlocksTracker(
self.block_allocator)
def can_allocate(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
num_required_blocks = BlockTable.get_num_required_blocks(
seq.get_token_ids(),
block_size=self.block_size,
num_lookahead_slots=num_lookahead_slots,
)
if seq_group.is_encoder_decoder():
encoder_seq = seq_group.get_encoder_seq()
assert encoder_seq is not None
num_required_blocks += BlockTable.get_num_required_blocks(
encoder_seq.get_token_ids(),
block_size=self.block_size,
)
if self.max_block_sliding_window is not None:
num_required_blocks = min(num_required_blocks,
self.max_block_sliding_window)
# TODO
# limitations to be removed later
required_size = num_required_blocks * self.block_size
if required_size > LLM_MAX_PREFILL_SEQ_LEN:
logger.warning(
f"This model's maximum input seq length limit is "
f"{LLM_MAX_PREFILL_SEQ_LEN} tokens, but this request requires "
f"{required_size} tokens. Please reduce the length of the "
f"input messages.")
return AllocStatus.NEVER
# Use watermark to avoid frequent cache eviction.
# NOTE:
# num of the gpu blocks for each seq_id might not be the same
# since each seq can use different blk group number
total_gpu_blocks = self.block_allocator.get_num_total_blocks(device=Device.GPU, seq_id=seq.seq_id)
if (total_gpu_blocks
- num_required_blocks < self.watermark_blocks):
return AllocStatus.NEVER
# NOTE: num_required_blocks should be aligned up to the 8K block group before the comparison
block_num = env_blk_grp_size // self.block_size
if num_required_blocks % block_num:  # align the request up to a full block group
num_required_blocks = (num_required_blocks // block_num + 1) * block_num
if total_gpu_blocks % block_num:  # align the capacity down to a full block group
total_gpu_blocks = (total_gpu_blocks // block_num) * block_num
# Use the aligned memory size to decide whether to reject the request.
if total_gpu_blocks - num_required_blocks < self.watermark_blocks:
logging.warning("gpu memory may not enough, please try shorter sequence")
return AllocStatus.NEVER
num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
device=Device.GPU, seq_id=seq.seq_id)
# logging.warning(f"free blocks: {num_free_gpu_blocks} required: {num_required_blocks} watermark: {self.watermark_blocks}")
if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
return AllocStatus.OK
else:
# logging.warning(f"free_blocks:{num_free_gpu_blocks} "
# f"required_blocks:{num_required_blocks} "
# f"watermark:{self.watermark_blocks} "
# f"allocate seq:{seq.seq_id} later")
return AllocStatus.LATER
def _allocate_sequence(self, seq: Sequence) -> BlockTable:
block_table = BlockTable(
block_size=self.block_size,
block_allocator=self.block_allocator,
max_block_sliding_window=self.max_block_sliding_window,
)
if seq.get_token_ids():
# NOTE: If there are any factors affecting the block besides
# token_ids, they should be added as input to extra_hash.
extra_hash = seq.extra_hash()
# Add blocks to the block table only if the sequence is non empty.
block_table.allocate(token_ids=seq.get_token_ids(),
extra_hash=extra_hash,
seq_id=seq.seq_id)
return block_table
def allocate(self, seq_group: SequenceGroup) -> None:
# Allocate self-attention block tables for decoder sequences
waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
assert not (set(seq.seq_id for seq in waiting_seqs)
& self.block_tables.keys()), "block table already exists"
# NOTE: Here we assume that all sequences in the group have the same
# prompt.
seq = waiting_seqs[0]
block_table: BlockTable = self._allocate_sequence(seq)
self.block_tables[seq.seq_id] = block_table
# Track seq
self._last_access_blocks_tracker.add_seq(seq.seq_id)
# Assign the block table for each sequence.
for seq in waiting_seqs[1:]:
self.block_tables[seq.seq_id] = block_table.fork()
# Track seq
self._last_access_blocks_tracker.add_seq(seq.seq_id)
# Allocate cross-attention block table for encoder sequence
#
# NOTE: Here we assume that all sequences in the group have the same
# encoder prompt.
request_id = seq_group.request_id
assert (request_id
not in self.cross_block_tables), \
"block table already exists"
check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
if seq_group.is_encoder_decoder():
encoder_seq = seq_group.get_encoder_seq()
assert encoder_seq is not None
block_table = self._allocate_sequence(encoder_seq)
self.cross_block_tables[request_id] = block_table
def can_append_slots(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> bool:
"""Determine if there is enough space in the GPU KV cache to continue
generation of the specified sequence group.
We use a worst-case heuristic: assume each touched block will require a
new allocation (either via CoW or new block). We can append slots if the
number of touched blocks is less than the number of free blocks.
"Lookahead slots" are slots that are allocated in addition to the slots
for known tokens. The contents of the lookahead slots are not defined.
This is used by speculative decoding when speculating future tokens.
"""
num_touched_blocks = 0
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
block_table = self.block_tables[seq.seq_id]
num_touched_blocks += (
block_table.get_num_blocks_touched_by_append_slots(
token_ids=block_table.get_unseen_token_ids(
seq.get_token_ids()),
num_lookahead_slots=num_lookahead_slots,
))
num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
Device.GPU, seq_id=seq.seq_id)
# NOTE: if False, trigger RECOMPUTE
return num_touched_blocks <= num_free_gpu_blocks
def append_slots(
self,
seq: Sequence,
num_lookahead_slots: int,
) -> List[Tuple[int, int]]:
block_table = self.block_tables[seq.seq_id]
block_table.append_token_ids(
token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()),
num_lookahead_slots=num_lookahead_slots,
num_computed_slots=seq.data.get_num_computed_tokens(),
extra_hash=seq.extra_hash(),
seq_id=seq.seq_id
)
# Return any new copy-on-writes.
new_cows = self.block_allocator.clear_copy_on_writes(seq.seq_id)
return new_cows
def free(self, seq: Sequence) -> None:
seq_id = seq.seq_id
if seq_id not in self.block_tables:
# Already freed or haven't been scheduled yet.
return
# Update seq block ids with the latest access time
self._last_access_blocks_tracker.update_seq_blocks_last_access(
seq_id, self.block_tables[seq.seq_id].physical_block_ids)
# Untrack seq
self._last_access_blocks_tracker.remove_seq(seq_id)
self._computed_blocks_tracker.remove_seq(seq_id)
# Free table/blocks
self.block_tables[seq_id].free(seq_id)
del self.block_tables[seq_id]
def remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None:
seq_id = seq.seq_id
self._computed_blocks_tracker.remove_seq(seq_id)
def free_cross(self, seq_group: SequenceGroup) -> None:
request_id = seq_group.request_id
if request_id not in self.cross_block_tables:
# Already freed or hasn't been scheduled yet.
return
self.cross_block_tables[request_id].free()
del self.cross_block_tables[request_id]
def get_block_table(self, seq: Sequence) -> List[int]:
block_ids = self.block_tables[seq.seq_id].physical_block_ids
return block_ids # type: ignore
def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
request_id = seq_group.request_id
assert request_id in self.cross_block_tables
block_ids = self.cross_block_tables[request_id].physical_block_ids
assert all(b is not None for b in block_ids)
return block_ids # type: ignore
def access_all_blocks_in_seq(self, seq: Sequence, now: float):
if self.enable_caching:
# Record the latest access time for the sequence. The actual update
# of the block ids is deferred to the sequence free(..) call, since
# only during freeing of block ids, the blocks are actually added to
# the evictor (which is when the most updated time is required)
# (This avoids expensive calls to mark_blocks_as_accessed(..))
self._last_access_blocks_tracker.update_last_access(
seq.seq_id, now)
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
# If prefix caching is enabled, mark immutable blocks as computed
# right after they have been scheduled (for prefill). This assumes
# the scheduler is synchronous so blocks are actually computed when
# scheduling the next batch.
self.block_allocator.mark_blocks_as_computed([])
def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
"""Determine which blocks for which we skip prefill.
With prefix caching we can skip prefill for previously-generated blocks.
Currently, the attention implementation only supports skipping cached
blocks if they are a contiguous prefix of cached blocks.
This method determines which blocks can be safely skipped for all
sequences in the sequence group.
"""
computed_seq_block_ids = []
for seq in seqs:
all_blocks = self.block_tables[seq.seq_id].physical_block_ids
num_cached_tokens = (
self._computed_blocks_tracker.get_num_cached_tokens(seq))
assert num_cached_tokens % self.block_size == 0
num_cached_blocks = num_cached_tokens // self.block_size
computed_block_ids = all_blocks[:num_cached_blocks]
computed_seq_block_ids.append(computed_block_ids)
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
return self.block_allocator.get_common_computed_block_ids(
computed_seq_block_ids) # type: ignore
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
if parent_seq.seq_id not in self.block_tables:
# Parent sequence has either been freed or never existed.
return
src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.fork()
# Track child seq
self._last_access_blocks_tracker.add_seq(child_seq.seq_id)
def can_swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> AllocStatus:
"""Returns the AllocStatus for the given sequence_group
with num_lookahead_slots.
Args:
sequence_group (SequenceGroup): The sequence group to swap in.
num_lookahead_slots (int): Number of lookahead slots used in
speculative decoding, default to 0.
Returns:
AllocStatus: The AllocStatus for the given sequence group.
"""
return self._can_swap(seq_group, Device.GPU, SequenceStatus.SWAPPED,
num_lookahead_slots)
def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
"""Returns the block id mapping (from CPU to GPU) generated by
swapping in the given seq_group with num_lookahead_slots.
Args:
seq_group (SequenceGroup): The sequence group to swap in.
Returns:
List[Tuple[int, int]]: The mapping of swapping block from CPU
to GPU.
"""
physical_block_id_mapping = []
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
blocks = self.block_tables[seq.seq_id].blocks
if len(blocks) == 0:
continue
seq_swap_mapping = self.block_allocator.swap(blocks=blocks,
src_device=Device.CPU,
dst_device=Device.GPU)
# Refresh the block ids of the table (post-swap)
self.block_tables[seq.seq_id].update(blocks)
seq_physical_block_id_mapping = {
self.block_allocator.get_physical_block_id(
Device.CPU, cpu_block_id):
self.block_allocator.get_physical_block_id(
Device.GPU, gpu_block_id)
for cpu_block_id, gpu_block_id in seq_swap_mapping.items()
}
physical_block_id_mapping.extend(
list(seq_physical_block_id_mapping.items()))
return physical_block_id_mapping
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
"""Returns whether we can swap out the given sequence_group
with num_lookahead_slots.
Args:
seq_group (SequenceGroup): The sequence group to swap out.
num_lookahead_slots (int): Number of lookahead slots used in
speculative decoding, default to 0.
Returns:
bool: Whether it's possible to swap out current sequence group.
"""
alloc_status = self._can_swap(seq_group, Device.CPU,
SequenceStatus.RUNNING)
return alloc_status == AllocStatus.OK
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
"""Returns the block id mapping (from GPU to CPU) generated by
swapping out the given sequence_group with num_lookahead_slots.
Args:
sequence_group (SequenceGroup): The sequence group to swap out.
Returns:
List[Tuple[int, int]]: The mapping of swapping block from
GPU to CPU.
"""
physical_block_id_mapping = []
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
blocks = self.block_tables[seq.seq_id].blocks
if len(blocks) == 0:
continue
seq_swap_mapping = self.block_allocator.swap(blocks=blocks,
src_device=Device.GPU,
dst_device=Device.CPU)
# Refresh the block ids of the table (post-swap)
self.block_tables[seq.seq_id].update(blocks)
seq_physical_block_id_mapping = {
self.block_allocator.get_physical_block_id(
Device.GPU, gpu_block_id):
self.block_allocator.get_physical_block_id(
Device.CPU, cpu_block_id)
for gpu_block_id, cpu_block_id in seq_swap_mapping.items()
}
physical_block_id_mapping.extend(
list(seq_physical_block_id_mapping.items()))
return physical_block_id_mapping
def get_num_free_gpu_blocks(self) -> int:
return self.block_allocator.get_num_free_blocks(Device.GPU)
def get_num_free_cpu_blocks(self) -> int:
return self.block_allocator.get_num_free_blocks(Device.CPU)
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_allocator.get_prefix_cache_hit_rate(device)
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
return self.block_allocator.reset_prefix_cache(device)
def _can_swap(self,
seq_group: SequenceGroup,
device: Device,
status: SequenceStatus,
num_lookahead_slots: int = 0) -> AllocStatus:
"""Returns the AllocStatus for swapping in/out the given sequence_group
on to the 'device'.
Args:
sequence_group (SequenceGroup): The sequence group to swap in/out.
device (Device): device to swap the 'seq_group' on.
status (SequenceStatus): The status of sequence which is needed
for action. RUNNING for swap out and SWAPPED for swap in
num_lookahead_slots (int): Number of lookahead slots used in
speculative decoding, default to 0.
Returns:
AllocStatus: The AllocStatus for swapping in/out the given
sequence_group on to the 'device'.
"""
# First determine the number of blocks that will be touched by this
# swap. Then verify if there are available blocks in the device
# to perform the swap.
num_blocks_touched = 0
blocks: List[Block] = []
for seq in seq_group.get_seqs(status=status):
block_table = self.block_tables[seq.seq_id]
if block_table.blocks is not None:
# Compute the number blocks to touch for the tokens to be
# appended. This does NOT include the full blocks that need
# to be touched for the swap.
num_blocks_touched += \
block_table.get_num_blocks_touched_by_append_slots(
block_table.get_unseen_token_ids(seq.get_token_ids()),
num_lookahead_slots=num_lookahead_slots)
blocks.extend(block_table.blocks)
# Compute the number of full blocks to touch and add it to the
# existing count of blocks to touch.
num_blocks_touched += self.block_allocator.get_num_full_blocks_touched(
blocks, device=device)
watermark_blocks = 0
if device == Device.GPU:
watermark_blocks = self.watermark_blocks
if self.block_allocator.get_num_total_blocks(
device) < num_blocks_touched:
return AllocStatus.NEVER
elif self.block_allocator.get_num_free_blocks(
device) - num_blocks_touched >= watermark_blocks:
return AllocStatus.OK
else:
return AllocStatus.LATER
def get_num_cached_tokens(self, seq: Sequence) -> int:
"""Get the number of tokens in blocks that are already computed and
cached in the block manager for the sequence.
"""
return self._computed_blocks_tracker.get_num_cached_tokens(seq)
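# A standalone sketch of the block-group alignment performed in can_allocate
# above. The numbers are hypothetical (block_size=16, BLOCK_GROUP_SIZE=8192);
# the real values come from the model config and vars.BLOCK_GROUP_SIZE.
def _alignment_sketch(num_required_blocks: int = 600,
                      total_gpu_blocks: int = 1500,
                      block_size: int = 16,
                      block_group_size: int = 8192) -> tuple:
    block_num = block_group_size // block_size  # 512 blocks per group
    if num_required_blocks % block_num:
        # the request is rounded up to a full group: 600 -> 1024
        num_required_blocks = (num_required_blocks // block_num + 1) * block_num
    if total_gpu_blocks % block_num:
        # the capacity is rounded down to a full group: 1500 -> 1024
        total_gpu_blocks = (total_gpu_blocks // block_num) * block_num
    return num_required_blocks, total_gpu_blocks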

View File

View File

@@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.distributed
def tensor_model_parallel_all_reduce_with_odsp(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group."""
from vllm.distributed import get_tp_group
try:
total_bytes = input_.numel() * input_.element_size() * get_tp_group().world_size
# only support 4M now
if total_bytes < 4194304:
from torch_vacc.vacc import all_reduce
return all_reduce(input_,
get_tp_group().rank_in_group,
get_tp_group().world_size,
get_tp_group().group_id,
dev_info = get_tp_group().rank_device_infos)
except Exception as e:
print("all_reduce by DSP run Fail, now use vccl-ops", e, input_.shape, input_.dtype)
return get_tp_group().all_reduce(input_)
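# A small sketch (not part of the function above) of the 4 MiB size gate:
# the payload is counted as numel * element_size * world_size. For example,
# a [1024, 512] bfloat16 tensor across 4 ranks is exactly 4_194_304 bytes,
# so it just misses the cutoff and takes the vccl all_reduce path instead.
def _fits_dsp_all_reduce(t: torch.Tensor, world_size: int,
                         limit_bytes: int = 4194304) -> bool:
    return t.numel() * t.element_size() * world_size < limit_bytes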

View File

@@ -0,0 +1,24 @@
import torch
import torch.distributed as dist
def all_gather_into_tensor(self, input_: torch.Tensor, dim: int = -1, output_tensor: torch.Tensor = None) -> torch.Tensor:
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
# Allocate output tensor.
# [N,] => [N*world_size], 1D Tensor
if output_tensor is None:
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
# print("o tensor is:", output_tensor.shape, "i tensor is:", input_.shape, input_size)
# All-gather.
dist.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
return output_tensor
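# Shape sketch for the concat-style gather above (hypothetical sizes, no
# process group involved): an input of shape [N, ...] gathered across
# world_size ranks yields [N * world_size, ...], with rank i's slice
# starting at offset i * N along dim 0.
def _gathered_shape(input_shape: tuple, world_size: int) -> tuple:
    return (input_shape[0] * world_size, ) + tuple(input_shape[1:])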

View File

@@ -0,0 +1,485 @@
import torch
import torch.distributed
from vllm.platforms import current_platform
from vllm.utils import supports_custom_op
from collections import namedtuple
from typing import (Any, Dict, List, Optional, Tuple,
Union)
from torch.distributed import ProcessGroup
from vllm.distributed.parallel_state import TensorMetadata
# memory recycler
MEMORY_RECYCLER_KEY = ['previous_hidden_states']
def _split_tensor_dict_concat(
tensor_dict: Dict[str, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
"""
# all_tensor_list = ['input_tokens','input_positions', 'slot_mapping','seq_lens_tensor', 'context_lens_tensor','block_tables', 'query_start_loc','seq_start_loc', 'selected_token_indices']
metadata_list: List[Tuple[str, Any]] = []
tensor_list: List[torch.Tensor] = []
all_tensor = []
all_tensor_numel = 0
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
# Note: we cannot use `value.device` here,
# because it contains not only the device type but also the device
# index (e.g. "cuda:0"). We only need the device type.
# receiving side will set the device index.
device = value.device.type
if not value.is_cpu and value.numel() > 0:
value_bytes_tensor = value.view(torch.int8)
all_tensor.append(value_bytes_tensor.view([-1]))
all_tensor_numel += value_bytes_tensor.numel()
# tensor_list.append(value)
metadata_list.append(
(key, TensorMetadata(device, value.dtype, value.size())))
else:
metadata_list.append((key, value))
if len(all_tensor) != 0:
memory_recycler_dynamic_output = None
# Compute the total size of all_tensor
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler, DeepseekMTPMemoryRecycler
if isinstance(memory_recycler, DeepseekMTPMemoryRecycler):
memory_recycler_dynamic_output = memory_recycler.DYNAMIC_OUTPUT_BUFFER.view(torch.int8)[:all_tensor_numel]
if memory_recycler_dynamic_output is not None:
all_tensor = torch.concatenate(all_tensor, 0, out = memory_recycler_dynamic_output)
else:
all_tensor = torch.concatenate(all_tensor, 0)
tensor_list.append(all_tensor)
metadata_list.append(("all_tensor", TensorMetadata(all_tensor.device.type, all_tensor.dtype, all_tensor.size())))
return metadata_list, tensor_list
def all_gather_to_rank0(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
# For TPUs, use TPU communicator.
tpu_comm = self.tpu_communicator
if tpu_comm is not None and not tpu_comm.disabled:
return tpu_comm.all_gather(input_, dim)
# For HPUs, use HPU communicator.
hpu_comm = self.hpu_communicator
if hpu_comm is not None and not hpu_comm.disabled:
return hpu_comm.all_gather(input_, dim)
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * world_size, ) + input_size[1:]
try:
total_bytes = input_.numel() * input_.element_size() * world_size
# only support 4M now
if total_bytes < 4194304:
from torch_vacc.vacc.custom_ops import all_gather
output_tensor = all_gather(input_, self.rank_in_group, self.world_size, self.group_id,
dev_info = self.rank_device_infos)
if self.rank_in_group != 0:
output_tensor = None
else:
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(world_size *
input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
except Exception as e:
print("all_gather by DSP run Fail, now use vccl-ops", e, input_.shape, input_.dtype)
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
torch.distributed.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
if self.rank_in_group != 0:
output_tensor = None
else:
# Reshape
output_tensor = output_tensor.reshape((world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(world_size *
input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
def generate_group_id(self, group_id):
self.group_id = group_id
def generate_rank_device_infos(self):
import numpy as np
import os
# encoder rank_dev_list
def combine_arrays(a, b):
a = np.asarray(a, dtype=np.uint32)
b = np.asarray(b, dtype=np.uint32)
if len(a) != len(b):
raise ValueError("两个数组的长度必须一致。")
a_shifted = np.left_shift(a, 16)
combined = np.bitwise_or(a_shifted, b)
return combined.tolist()
# decoder rank_dev_list
def uncombine_array(array):
array = np.asarray(array, dtype=np.uint32)
o_0 = array >> 16
o_1 = array << 16 >> 16
return o_0, o_1
physical_devices = self.ranks
visible_devices = os.getenv('VACC_VISIBLE_DEVICES')
if visible_devices is not None:
device_list = visible_devices.split(',')
device_count = len(device_list)
assert device_count >= len(self.ranks), f'VACC_VISIBLE_DEVICES:{device_count} is less than ranks:{len(self.ranks)}, please designate more devices'
physical_devices = [int(device_list[i]) for i in self.ranks]
# print("[vccl] logic_devices:physical_devices ", self.ranks, physical_devices)
logic_ranks = [self.ranks.index(rank) for rank in self.ranks]
self.rank_device_infos = combine_arrays(logic_ranks, physical_devices)
def get_bitwidth(dtype):
if dtype.is_floating_point:
return torch.finfo(dtype).bits
else:
return torch.iinfo(dtype).bits
class GroupCoordinator:
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
"""
User-facing all-reduce function before we actually call the
all-reduce operation.
We need this because Dynamo does not support passing an arbitrary
object (`self` in this case) to a custom op. We need to pass the
group name as a string, and then look up the group coordinator from
the group name, dispatch the all-reduce operation to the group
coordinator.
In addition, PyTorch custom ops do not support mutation or returning
a new tensor in the same op. So we need to figure out if the op is
in-place or out-of-place ahead of time.
"""
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_
if input_.is_cpu:
import intel_extension_for_pytorch as ipex
ipex.distributed.all_reduce(input_, group=self.device_group)
return input_
# vacc impl
# s0, s1 = input_.shape
# output_tensor = torch.empty([32, s0, s1],
# dtype=input_.dtype,
# device=input_.device)
# torch.distributed.all_gather_into_tensor(output_tensor,
# input_,
# group=self.device_group)
# input_ = output_tensor.sum(dim=0)
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
def broadcast_tensor_dict(
self,
tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
# all_tensor_list = ['input_tokens','input_positions', 'slot_mapping','seq_lens_tensor', 'context_lens_tensor','block_tables', 'query_start_loc','seq_start_loc', 'selected_token_indices']
# Bypass the function if we are using only 1 GPU.
if (not torch.distributed.is_initialized() or self.world_size == 1):
return tensor_dict
group = self.device_group
metadata_group = self.cpu_group
assert src < self.world_size, f"Invalid src rank ({src})"
rank_in_group = self.rank_in_group
if rank_in_group == src:
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
# metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
metadata_list, tensor_list = _split_tensor_dict_concat(tensor_dict)
# `metadata_list` lives in CPU memory.
# `broadcast_object_list` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
# metadata_list contains (key, value) pairs; for tensors the value is metadata (shape/dtype only, no data)
self.broadcast_object(metadata_list, src=src)
async_handles = []
for tensor in tensor_list:
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(tensor,
src=self.ranks[src],
group=metadata_group,
async_op=True)
async_handles.append(handle)
else:
# use group for GPU tensors
total_bytes = tensor.numel() * tensor.element_size()
use_dist = True
# only support 4M now
if total_bytes < 4194304:
try:
from torch_vacc.vacc.custom_ops import broadcast
#print("send tensor is:", tensor.shape, tensor.dtype, self.rank)
broadcast(tensor, self.rank_in_group, self.world_size, root_rank=0, group_id=self.group_id,
dev_info = self.rank_device_infos)
use_dist = False
except Exception as e:
print("odsp broadcast run fail, now using distributed:", e)
if use_dist:
handle = torch.distributed.broadcast(tensor,
src=self.ranks[src],
group=group,
async_op=True)
async_handles.append(handle)
for async_handle in async_handles:
async_handle.wait()
else:
# other rank
metadata_list = self.broadcast_object(None, src=src)
tensor_dict = {}
async_handles = []
tensor_size = [] # list of [key, shape] used to split all_tensor
dataType_list = []
for key, value in metadata_list:
# if rank_in_group == 1:
# print('rank1 k v ', key, value)
if isinstance(value, TensorMetadata):
tensor = None
# dtype is fixed to int8
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler, DeepseekMTPMemoryRecycler
if isinstance(memory_recycler, DeepseekMTPMemoryRecycler) and value.dtype == torch.int8:
tensor = memory_recycler.DYNAMIC_OUTPUT_BUFFER.view(value.dtype)[:value.size.numel()].view(value.size)
if tensor is None:
tensor = torch.empty(value.size,
dtype=value.dtype,
device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(
tensor,
src=self.ranks[src],
group=metadata_group,
async_op=True)
async_handles.append(handle)
tensor_dict[key] = tensor
else:
# use group for GPU tensors
if key == "all_tensor":
total_bytes = tensor.numel() * tensor.element_size()
use_dist = True
# only support 4M now
if total_bytes < 4194304:
try:
from torch_vacc.vacc.custom_ops import broadcast
tensor = broadcast(tensor, self.rank_in_group, self.world_size, root_rank=0, group_id=self.group_id,
dev_info = self.rank_device_infos)
use_dist = False
except Exception as e:
print("dsp brocast run fail, now using distributed:", e)
if use_dist:
handle = torch.distributed.broadcast(
tensor, # the concatenated tensor
src=self.ranks[src],
group=group,
async_op=False)
# Split all_tensor by the recorded (key, shape) pairs and store the pieces into tensor_dict
start = 0
idx = 0
for ki_vi in tensor_size:
ki, vi = ki_vi
length = vi.numel() * int(get_bitwidth(dataType_list[idx]) / 8)
if ki in MEMORY_RECYCLER_KEY:
tensor_dict[ki] = tensor[start:start+length].view(dataType_list[idx]).view(vi)
else:
value_tensor = torch.empty(vi,
dtype=dataType_list[idx],
device=value.device)
recv_tensor = tensor[start:start+length].view(dataType_list[idx]).view(vi)
tensor_dict[ki] = value_tensor.copy_(recv_tensor)
start += length
idx += 1
else:
dataType_list.append(value.dtype)
tensor_size.append([key, value.size])
else:
tensor_dict[key] = value
for async_handle in async_handles:
async_handle.wait()
return tensor_dict
def all_gather(self, input_: torch.Tensor, dim: int = -1, output_: torch.Tensor = None) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if self.use_custom_op_call:
return torch.ops.vllm.all_gather(input_,
dim,
world_size,
group_name=self.unique_name)
else:
# Use the all_gather variant that reuses a caller-provided output buffer
if output_ is not None:
return self.device_communicator.all_gather_into_tensor(input_, dim, output_)
return self._all_gather_out_place(input_, dim)
def recv_tensor_dict(
self,
src: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return None
all_gather_size = (1 if all_gather_group is None else
all_gather_group.world_size)
all_gather_rank = (0 if all_gather_group is None else
all_gather_group.rank_in_group)
group = self.device_group
metadata_group = self.cpu_group
if src is None:
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})"
recv_metadata_list = self.recv_object(src=src)
tensor_dict: dict[str, Any] = {}
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import alloc_pipeline_parallel_recycler_buffer
memory_recycler_list = ["hidden_states", "residual"]
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
# Decide whether a recycled buffer should be used:
# 1. the key is in [hidden_states, residual], i.e. it comes from pipeline parallelism
# 2. a tensor for that key can be obtained from the memory_recycling module
use_create_recycler_tensor = False
tensor = None
value_tensor = None # receives the full all_gather output
if key in memory_recycler_list:
tensor = alloc_pipeline_parallel_recycler_buffer(value.size, value.dtype, key)
if tensor is not None:
use_create_recycler_tensor = True
if not use_create_recycler_tensor:
tensor = torch.empty(value.size,
dtype=value.dtype,
device=value.device)
value_tensor = tensor
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
continue
# send-allgather: send only a slice, then do allgather.
use_all_gather = (all_gather_group is not None
and tensor.numel() % all_gather_size == 0)
if use_all_gather:
orig_shape = tensor.shape
# With the recycled buffer a view is enough; no reshape needed
if use_create_recycler_tensor:
tensor = tensor.view(all_gather_size,
-1)[all_gather_rank].contiguous()
else:
tensor = tensor.reshape(all_gather_size,
-1)[all_gather_rank]
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.recv(tensor,
src=self.ranks[src],
group=metadata_group)
else:
# use group for GPU tensors
torch.distributed.recv(tensor,
src=self.ranks[src],
group=group)
if use_all_gather:
# do the allgather
if use_create_recycler_tensor:
tensor = all_gather_group.all_gather( # type: ignore
tensor, dim=0, output_ = value_tensor)
tensor = tensor.view(orig_shape)
else:
tensor = all_gather_group.all_gather( # type: ignore
tensor, dim=0)
tensor = tensor.reshape(orig_shape)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
return tensor_dict
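# A self-contained sketch of the "single concatenated payload" idea used by
# _split_tensor_dict_concat above. The tensors are hypothetical and the
# TensorMetadata / device / broadcast handling of the real path is omitted.
def _concat_split_roundtrip():
    src = {"hidden": torch.ones(2, 3, dtype=torch.float16),
           "ids": torch.arange(4, dtype=torch.int32)}
    # sender side: record metadata, flatten every tensor to int8 bytes and
    # concatenate them into one payload tensor
    meta = [(key, (val.dtype, val.shape)) for key, val in src.items()]
    payload = torch.cat([val.view(torch.int8).view(-1) for val in src.values()])
    # receiver side: slice the payload back into per-key tensors
    out, start = {}, 0
    for key, (dtype, shape) in meta:
        nbytes = shape.numel() * torch.empty((), dtype=dtype).element_size()
        out[key] = payload[start:start + nbytes].view(dtype).view(shape)
        start += nbytes
    return out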

View File

View File

@@ -0,0 +1,158 @@
import argparse
import copy
import dataclasses
import functools
import json
import sys
import threading
import warnings
from dataclasses import MISSING, dataclass, fields, is_dataclass
from itertools import permutations
from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
Type, TypeVar, Union, cast, get_args, get_origin)
import regex as re
import torch
from pydantic import TypeAdapter, ValidationError
from typing_extensions import TypeIs, deprecated
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.plugins import load_general_plugins
from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, GiB_bytes, get_ip,
is_in_ray_actor)
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import CpuArchEnum, current_platform
logger = init_logger(__name__)
def _set_default_args(self, usage_context: UsageContext,
model_config: ModelConfig) -> None:
"""Set Default Arguments for V1 Engine."""
# Upstream V1 enables chunked prefill and prefix caching by default
# for non-pooling tasks; the vacc backend disables both here.
self.enable_chunked_prefill = False
self.enable_prefix_caching = False
if model_config.runner_type != "pooling":
# TODO: When prefix caching supports prompt embeds inputs, this
# check can be removed.
if (self.enable_prompt_embeds
and self.enable_prefix_caching is not False):
logger.warning(
"--enable-prompt-embeds and --enable-prefix-caching "
"are not supported together in V1. Prefix caching has "
"been disabled.")
# V1 should use the new scheduler by default.
# Swap it only if this arg is set to the original V0 default
if self.scheduler_cls == EngineArgs.scheduler_cls:
self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"
# When no user override, set the default values based on the usage
# context.
# Use different default values for different hardware.
# Try to query the device name on the current platform. If it fails,
# it may be because the platform that imports vLLM is not the same
# as the platform that vLLM is running on (e.g. the case of scaling
# vLLM with Ray) and has no GPUs. In this case we use the default
# values for non-H100/H200 GPUs.
try:
device_memory = current_platform.get_device_total_memory()
device_name = current_platform.get_device_name().lower()
except Exception:
# This is only used to set default_max_num_batched_tokens
device_memory = 0
# NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
# throughput, see PR #17885 for more details.
# So here we do an extra device name check to prevent such regression.
from vllm.usage.usage_lib import UsageContext
if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
# For GPUs like H100 and MI300x, use larger default values.
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 16384,
UsageContext.OPENAI_API_SERVER: 8192,
}
default_max_num_seqs = {
UsageContext.LLM_CLASS: 1024,
UsageContext.OPENAI_API_SERVER: 1024,
}
else:
# TODO(woosuk): Tune the default values for other hardware.
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 8192,
UsageContext.OPENAI_API_SERVER: 2048,
}
default_max_num_seqs = {
UsageContext.LLM_CLASS: 4,
UsageContext.OPENAI_API_SERVER: 4,
}
# tpu specific default values.
if current_platform.is_tpu():
default_max_num_batched_tokens_tpu = {
UsageContext.LLM_CLASS: {
'V6E': 2048,
'V5E': 1024,
'V5P': 512,
},
UsageContext.OPENAI_API_SERVER: {
'V6E': 1024,
'V5E': 512,
'V5P': 256,
}
}
# cpu specific default values.
if current_platform.is_cpu():
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 4096 * world_size,
UsageContext.OPENAI_API_SERVER: 2048 * world_size,
}
default_max_num_seqs = {
UsageContext.LLM_CLASS: 256 * world_size,
UsageContext.OPENAI_API_SERVER: 128 * world_size,
}
use_context_value = usage_context.value if usage_context else None
if (self.max_num_batched_tokens is None
and usage_context in default_max_num_batched_tokens):
if current_platform.is_tpu():
chip_name = current_platform.get_device_name()
if chip_name in default_max_num_batched_tokens_tpu[
usage_context]:
self.max_num_batched_tokens = \
default_max_num_batched_tokens_tpu[
usage_context][chip_name]
else:
self.max_num_batched_tokens = \
default_max_num_batched_tokens[usage_context]
else:
if not self.enable_chunked_prefill:
self.max_num_batched_tokens = model_config.max_model_len
else:
self.max_num_batched_tokens = \
default_max_num_batched_tokens[usage_context]
logger.debug(
"Setting max_num_batched_tokens to %d for %s usage context.",
self.max_num_batched_tokens, use_context_value)
if (self.max_num_seqs is None
and usage_context in default_max_num_seqs):
self.max_num_seqs = min(default_max_num_seqs[usage_context],
self.max_num_batched_tokens or sys.maxsize)
logger.debug("Setting max_num_seqs to %d for %s usage context.",
self.max_num_seqs, use_context_value)
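# Worked sketch (hypothetical setup) of the CPU branch above: with
# pipeline_parallel_size=1 and tensor_parallel_size=2 under the
# OPENAI_API_SERVER context, world_size=2, so the defaults become
# max_num_batched_tokens = 2048 * 2 = 4096 and max_num_seqs = 128 * 2 = 256
# unless the user has already set either value explicitly.
def _cpu_defaults_sketch(pipeline_parallel_size: int = 1,
                         tensor_parallel_size: int = 2) -> dict:
    world_size = pipeline_parallel_size * tensor_parallel_size
    return {
        "max_num_batched_tokens": 2048 * world_size,
        "max_num_seqs": 128 * world_size,
    }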

View File

@@ -0,0 +1,49 @@
from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext
from typing import Dict, Optional
from vllm.engine.metrics_types import StatLoggerBase
class LLMEngine:
@classmethod
def from_engine_args(
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
vllm_config = engine_args.create_engine_config(usage_context)
# patch to prevent num_speculative_tokens > 1
speculative_mode = hasattr(vllm_config, 'speculative_config')
if speculative_mode and \
hasattr(vllm_config.speculative_config, 'num_speculative_tokens') and \
vllm_config.speculative_config.num_speculative_tokens != 1:
raise ValueError(f'from_engine_args: only num_speculative_tokens == 1 is supported, but got {vllm_config.speculative_config.num_speculative_tokens}')
default_model_infos = "default"
if speculative_mode:
if hasattr(vllm_config.speculative_config, 'method'):
default_model_infos = vllm_config.speculative_config.method
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
vllm_vacc_config_manager().update_model_infos(default_model_infos)
import vllm.envs as envs
engine_cls = None
if envs.VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
engine_cls = V1LLMEngine
else:
from vllm.engine.llm_engine import LLMEngine as DefaultEngine
engine_cls = DefaultEngine
assert engine_cls is not None, f"LLMEngine is empty: {engine_cls}"
return engine_cls.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
stat_loggers=stat_loggers,
disable_log_stats=engine_args.disable_log_stats,
)

View File

@@ -0,0 +1,69 @@
from vllm.engine.metrics_types import (StatLoggerBase, Stats)
import vllm_vacc.vllm.model_executor.models.vars as global_vars
class LoggingStatLogger(StatLoggerBase):
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
def log(self, stats: Stats) -> None:
from vllm.engine.metrics import local_interval_elapsed, get_throughput, logger
"""Called by LLMEngine.
Logs to Stdout every self.local_interval seconds."""
# Save tracked stats for token counters.
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
# Update spec decode metrics
self.maybe_update_spec_decode_metrics(stats)
# Log locally every local_interval seconds.
if local_interval_elapsed(stats.now, self.last_local_log,
self.local_interval):
# Compute summary metrics for tracked stats (and log them
# to promethus if applicable).
prompt_throughput = get_throughput(self.num_prompt_tokens,
now=stats.now,
last_log=self.last_local_log)
generation_throughput = get_throughput(
self.num_generation_tokens,
now=stats.now,
last_log=self.last_local_log)
log_fn = logger.info
if not any((prompt_throughput, generation_throughput,
self.last_prompt_throughput,
self.last_generation_throughput)):
# Avoid log noise on an idle production system
log_fn = logger.debug
log_fn(
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Swapped: %d reqs, "
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
"CPU KV cache usage: %.1f%%., "
"Do sequences length: %s",
prompt_throughput,
generation_throughput,
stats.num_running_sys,
stats.num_swapped_sys,
stats.num_waiting_sys,
stats.gpu_cache_usage_sys * 100,
stats.cpu_cache_usage_sys * 100,
str(global_vars.DO_SEQ_LENS)
)
if (stats.cpu_prefix_cache_hit_rate >= 0
or stats.gpu_prefix_cache_hit_rate >= 0):
log_fn(
"Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%",
stats.gpu_prefix_cache_hit_rate * 100,
stats.cpu_prefix_cache_hit_rate * 100,
)
if self.spec_decode_metrics is not None:
logger.debug(
self._format_spec_decode_metrics_str(
self.spec_decode_metrics))
self._reset(stats, prompt_throughput, generation_throughput)

View File

@@ -0,0 +1,103 @@
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR,
RPCError,
RPCProcessRequest,
RPCAbortRequest)
from vllm.config import VllmConfig
import signal
from vllm.logger import init_logger
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__)
class MQLLMEngine:
def _handle_process_request(self, request: RPCProcessRequest):
"""Handle RPCProcessRequest by adding it to the LLMEngine."""
request_id = request.request_id
if self._errored_with is not None:
rpc_err = RPCError(request_id=request_id,
is_engine_errored=True,
exception=ENGINE_DEAD_ERROR(self._errored_with))
self._send_outputs(rpc_err)
try:
self.engine.add_request(
request_id=request_id,
prompt=request.prompt,
params=request.params,
lora_request=request.lora_request,
trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request,
priority=request.priority)
if self.log_requests:
from vllm.engine.multiprocessing.engine import logger
if request.prompt.get('prompt_token_ids') is not None:
# logger.info("Added request: %s, %s, prompt length: %s", request.request_id, request.prompt['prompt_token_ids'], len(request.prompt['prompt_token_ids']))
logger.info("Added request: %s, prompt length: %s", request.request_id, len(request.prompt['prompt_token_ids']))
else:
logger.info("Added request %s.", request.request_id)
except Exception as e:
# We do not set self._errored = True here, since the error
# is due to an issue adding this request to the engine,
# rather than an issue with the engine itself.
is_errored = self._errored_with is not None
rpc_err = RPCError(request_id=request_id,
is_engine_errored=is_errored,
exception=e)
self._send_outputs(rpc_err)
# Remove request from the engine.
self.engine.abort_request(request_id)
def _handle_abort_request(self, request: RPCAbortRequest):
self.engine.abort_request(request.request_id)
if self.log_requests:
from vllm.engine.multiprocessing.engine import logger
import vllm_vacc.vllm.model_executor.models.vars as global_vars
logger.info("Aborted request: %s, prompt length: %s", request.request_id, global_vars.DO_SEQ_LENS)
def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext,
ipc_path: str, disable_log_stats: bool,
disable_log_requests: bool, engine_alive):
# patch to prevent num_speculative_tokens > 1
speculative_mode = hasattr(vllm_config, 'speculative_config')
if speculative_mode and \
hasattr(vllm_config.speculative_config, 'num_speculative_tokens') and \
vllm_config.speculative_config.num_speculative_tokens != 1:
raise ValueError(f'run_mp_engine: only num_speculative_tokens == 1 is supported, but got {vllm_config.speculative_config.num_speculative_tokens}')
default_model_infos = "default"
if speculative_mode:
if hasattr(vllm_config.speculative_config, 'method'):
default_model_infos = vllm_config.speculative_config.method
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
vllm_vacc_config_manager().update_model_infos(default_model_infos)
try:
# Ensure we can serialize transformer config before spawning
maybe_register_config_serialize_by_value()
from vllm.engine.multiprocessing.engine import MQLLMEngine,signal_handler
engine = MQLLMEngine.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
disable_log_stats=disable_log_stats,
disable_log_requests=disable_log_requests,
ipc_path=ipc_path)
signal.signal(signal.SIGTERM, signal_handler)
engine.start()
except BaseException as e:
logger.exception(e)
engine_alive.value = False
raise e

View File

View File

@@ -0,0 +1,102 @@
import itertools
import warnings
from collections.abc import Sequence
from contextlib import contextmanager
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
cast, overload)
import cloudpickle
import torch.nn as nn
from pydantic import ValidationError
from tqdm.auto import tqdm
from typing_extensions import TypeVar, deprecated
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (RequestOutputKind, SamplingParams)
from vllm.platforms import current_platform
logger = init_logger(__name__)
_R = TypeVar("_R", default=Any)
class LLM:
DEPRECATE_LEGACY: ClassVar[bool] = True
def _validate_and_add_requests(
self,
prompts: Union[PromptType, Sequence[PromptType]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
*,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
priority: Optional[list[int]] = None,
) -> None:
if isinstance(prompts, (str, dict)):
# Convert a single prompt to a list.
prompts = [prompts]
num_requests = len(prompts)
if isinstance(params, Sequence) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params "
"must be the same.")
if isinstance(lora_request,
Sequence) and len(lora_request) != num_requests:
raise ValueError("The lengths of prompts and lora_request "
"must be the same.")
for sp in params if isinstance(params, Sequence) else (params, ):
if isinstance(sp, SamplingParams):
# We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY
# Add requests to the engine.
it = prompts
if use_tqdm:
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
it = tqdm_func(it, desc="Adding requests")
if (hasattr(current_platform, 'supports_v1') and current_platform.supports_v1(current_platform)):
batch_items = []
model_config = self.llm_engine.model_config
for i, prompt in enumerate(it):
request_id = str(next(self.request_counter))
# print("requset_id===========", request_id)
param = params[i] if isinstance(params, Sequence) else params
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(model_config.max_model_len,
param.truncate_prompt_tokens,
tokenization_kwargs)
batch_items.append((
request_id,
prompt,
params[i] if isinstance(params, Sequence) else params,
None, # arrival_time; pass None when it is not needed
(lora_request[i] if isinstance(lora_request, Sequence)
else lora_request),
tokenization_kwargs,
None, # trace_headers; None when no APM/tracing is configured
(priority[i] if priority else 0),
))
# Submit the whole batch to EngineCore in one shot via ADD_BULK
self.llm_engine.add_requests(batch_items)
else:
model_config = self.llm_engine.model_config
for i, prompt in enumerate(it):
param = params[i] if isinstance(params, Sequence) else params
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(model_config.max_model_len,
param.truncate_prompt_tokens,
tokenization_kwargs)
self._add_request(
prompt,
param,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
priority=priority[i] if priority else 0,
)

View File

@@ -0,0 +1,345 @@
import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import Optional, Union, cast
import jinja2
from fastapi import Request
from typing_extensions import assert_never
from concurrent.futures.thread import ThreadPoolExecutor
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse,
RequestResponseMetadata,
UsageInfo)
from vllm.entrypoints.openai.serving_engine import (
EmbedsPrompt as ServingEngineEmbedsPrompt)
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
TextTokensPrompt,
clamp_prompt_logprobs,
is_text_tokens_prompt)
# yapf: enable
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
is_tokens_prompt)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import merge_async_iterators
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_completion import logger
from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of, make_async,
random_uuid)
from vllm_vacc.vllm.model_executor.models.vars import LLM_MAX_PREFILL_SEQ_LEN
class OpenAIServingCompletion(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
enable_strict_batch_barrier: bool = True,
log_error_stack: bool = False,
):
self.engine_client = engine_client
self.model_config = model_config
self.max_model_len = model_config.max_model_len
self.models = models
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
self.enable_force_include_usage = enable_force_include_usage
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
self._async_tokenizer_pool: dict[AnyTokenizer,
AsyncMicrobatchTokenizer] = {}
self.log_error_stack = log_error_stack
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.default_sampling_params = (
self.model_config.get_diff_sampling_param())
if self.default_sampling_params:
source = self.model_config.generation_config
source = "model" if source == "auto" else source
logger.info("Using default completion sampling params from %s: %s",
source, self.default_sampling_params)
self.enable_strict_batch_barrier = enable_strict_batch_barrier
async def create_completion(
self,
request: CompletionRequest,
raw_request: Optional[Request] = None,
) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following feature:
- suffix (the language models we currently support do not support
suffix)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
# Return error for unsupported features.
if request.suffix is not None:
return self.create_error_response(
"suffix is not currently supported")
if request.echo and request.prompt_embeds is not None:
return self.create_error_response(
"Echo is unsupported with prompt embeds.")
if (request.prompt_logprobs is not None
and request.prompt_embeds is not None):
return self.create_error_response(
"prompt_logprobs is not compatible with prompt embeds.")
request_id = (
f"cmpl-"
f"{self._base_request_id(raw_request, request.request_id)}")
created_time = int(time.time())
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
try:
lora_request = self._maybe_get_adapters(request)
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
engine_prompts = await renderer.render_prompt_and_embeds(
prompt_or_prompts=request.prompt,
prompt_embeds=request.prompt_embeds,
deepstack_input_embeds=request.deepstack_input_embeds if hasattr(request, 'deepstack_input_embeds') else None,
config=self._build_render_config(request),
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except TypeError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except RuntimeError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except jinja2.TemplateError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
total_num_prompts = len(engine_prompts)
for i, engine_prompt in enumerate(engine_prompts):
sampling_params: Union[SamplingParams, BeamSearchParams]
# Mypy does not infer that engine_prompt will have only one of
# "prompt_token_ids" or "prompt_embeds" defined, and both of
# these as Union[object, the expected type], where it infers
# object if engine_prompt is a subclass of one of the
# typeddicts that defines both keys. Worse, because of
# https://github.com/python/mypy/issues/8586, mypy does not
# infer the type of engine_prompt correctly because of the
# enumerate. So we need an unnecessary cast here.
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
engine_prompt)
if is_embeds_prompt(engine_prompt):
input_length = len(engine_prompt["prompt_embeds"])
elif is_tokens_prompt(engine_prompt):
input_length = len(engine_prompt["prompt_token_ids"])
if input_length > LLM_MAX_PREFILL_SEQ_LEN:
raise ValueError(
f"This model's maximum input seq length limit is "
f"{LLM_MAX_PREFILL_SEQ_LEN} tokens. However, you requested "
f"({input_length} in the input messages, "
f"Please reduce the length of the input messages.")
else:
assert_never(engine_prompt)
if self.default_sampling_params is None:
self.default_sampling_params = {}
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
input_length=input_length,
default_sampling_params=self.default_sampling_params,
)
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params)
else:
sampling_params = request.to_sampling_params(
max_tokens,
self.model_config.logits_processor_pattern,
self.default_sampling_params,
)
# Inject strict batch barrier metadata so this batch is held
# until all items are ready, then scheduled together.
if (self.enable_strict_batch_barrier
and total_num_prompts > 1
and isinstance(sampling_params, SamplingParams)):
if sampling_params.extra_args is None:
sampling_params.extra_args = {}
sampling_params.extra_args.setdefault("barrier_group_id",
request_id)
sampling_params.extra_args.setdefault("barrier_group_size",
total_num_prompts)
request_id_item = f"{request_id}-{i}"
self._log_inputs(
request_id_item,
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
# Mypy inconsistently requires this second cast in different
# environments. It shouldn't be necessary (redundant from above)
# but pre-commit in CI fails without it.
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
engine_prompt)
if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search(
prompt=engine_prompt,
request_id=request_id,
params=sampling_params,
lora_request=lora_request,
)
else:
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
logger.error(e)
return self.create_error_response(str(e))
result_generator = merge_async_iterators(*generators)
model_name = self.models.model_name(lora_request)
num_prompts = len(engine_prompts)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. Noting that best_of is only supported in V0. In addition,
# we do not stream the results when use beam search.
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
# Streaming response
if stream:
return self.completion_stream_generator(
request,
engine_prompts,
result_generator,
request_id,
created_time,
model_name,
num_prompts=num_prompts,
tokenizer=tokenizer,
request_metadata=request_metadata,
enable_force_include_usage=self.enable_force_include_usage,
)
# Non-streaming response
final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
try:
async for i, res in result_generator:
final_res_batch[i] = res
for i, final_res in enumerate(final_res_batch):
assert final_res is not None
# The output should contain the input text
# We did not pass it into vLLM engine to avoid being redundant
# with the inputs token IDs
if final_res.prompt is None:
engine_prompt = engine_prompts[i]
final_res.prompt = None if is_embeds_prompt(
engine_prompt) else engine_prompt.get("prompt")
final_res_batch_checked = cast(list[RequestOutput],
final_res_batch)
response = self.request_output_to_completion_response(
final_res_batch_checked,
request,
request_id,
created_time,
model_name,
tokenizer,
request_metadata,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
if request.stream:
response_json = response.model_dump_json()
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return fake_stream_generator()
return response
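
For reference, the strict batch barrier above only engages when a single completion request carries more than one prompt. A minimal client-side sketch of such a request (the server URL, API key, and model name are placeholders, not part of this patch):

from openai import OpenAI

# Hypothetical client call: one completion request with several prompts.
# With enable_strict_batch_barrier, the server tags each sub-request's
# SamplingParams.extra_args with barrier_group_id / barrier_group_size so
# the whole batch is held and scheduled together.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.completions.create(
    model="my-model",                                     # placeholder
    prompt=["prompt one", "prompt two", "prompt three"],  # total_num_prompts > 1
    max_tokens=32,
)
for choice in resp.choices:
    print(choice.index, choice.text)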

View File

@@ -0,0 +1,191 @@
# SPDX-License-Identifier: Apache-2.0
import json
from concurrent.futures.thread import ThreadPoolExecutor
from http import HTTPStatus
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
Optional, Sequence, Tuple, TypedDict, Union)
from fastapi import Request
from pydantic import Field
from starlette.datastructures import Headers
from typing_extensions import Annotated
import torch
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
ConversationMessage,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages_futures,
resolve_chat_template_content_format)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest,
DetokenizeRequest,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
ErrorResponse, RerankRequest,
ScoreRequest,
TokenizeChatRequest,
TokenizeCompletionRequest)
# yapf: enable
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.entrypoints.openai.serving_engine import AnyRequest, TextTokensPrompt
# from vllm.model_executor.sampling_metadata import _SAMPLING_EPS
from vllm.v1.sample.sampler import _SAMPLING_EPS
import os
logger = init_logger(__name__)
from vllm_vacc.vllm.model_executor.models.vars import LLM_MAX_PREFILL_SEQ_LEN
from vllm_vacc.vllm.model_executor.models.vars import CUT_PREFILL_SEQ_LEN
class EmbedsPrompt(TypedDict):
prompt_embeds: torch.Tensor
deepstack_input_embeds: Optional[dict]
class OpenAIServing:
def _validate_input(
self,
request: AnyRequest,
input_ids: List[int],
input_text: str,
) -> TextTokensPrompt:
# Client-set parameter; if it is not set, the value is read again from generation_config.json.
if 0 < CUT_PREFILL_SEQ_LEN < len(input_ids):
cut_before = CUT_PREFILL_SEQ_LEN // 2
cut_after = CUT_PREFILL_SEQ_LEN - cut_before
input_ids = input_ids[:cut_before] + input_ids[-cut_after:]
token_num = len(input_ids)
if not self.model_config.pooler_config:
if (request.repetition_penalty is not None and abs(request.repetition_penalty - 1.0) >= _SAMPLING_EPS):
raise ValueError(
f"unsupport penalty for sampler"
f"request.repetition_penalty: {request.repetition_penalty}; "
f"Please remove penalty parameter in client and try again."
)
if request.min_p is not None and request.min_p > _SAMPLING_EPS:
raise ValueError(f"Unsupported min_p {request.min_p} for sampler")
if request.prompt_logprobs is not None:
raise ValueError(f"Unsupported prompt_logprobs {request.prompt_logprobs} for sampler")
# model_type = self.model_config.hf_config.model_type
# if model_type == "deepseek_v3":
if token_num > LLM_MAX_PREFILL_SEQ_LEN:
raise ValueError(
f"This model's maximum input seq length limit is "
f"{LLM_MAX_PREFILL_SEQ_LEN} tokens. However, you requested "
f"({token_num} in the input messages, "
f"Please reduce the length of the input messages.")
# Note: EmbeddingRequest and ScoreRequest don't have max_tokens
if isinstance(request,
(EmbeddingChatRequest, EmbeddingCompletionRequest,
ScoreRequest, RerankRequest)):
operation = "score" if isinstance(request, ScoreRequest) \
else "embedding generation"
if token_num > self.max_model_len:
raise ValueError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the input for {operation}. "
f"Please reduce the length of the input.")
return TextTokensPrompt(prompt=input_text,
prompt_token_ids=input_ids)
# Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
# and do not require model context length validation
if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
DetokenizeRequest)):
return TextTokensPrompt(prompt=input_text,
prompt_token_ids=input_ids)
# chat completion endpoint supports max_completion_tokens
if isinstance(request, ChatCompletionRequest):
# TODO(#9845): remove max_tokens when field dropped from OpenAI API
max_tokens = request.max_completion_tokens or request.max_tokens
else:
max_tokens = request.max_tokens
if max_tokens is None:
if token_num >= self.max_model_len:
raise ValueError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the messages, "
f"Please reduce the length of the messages.")
elif token_num + max_tokens > self.max_model_len:
raise ValueError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.")
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
def _log_inputs(
self,
request_id: str,
inputs,
params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]],
lora_request: Optional[LoRARequest],
) -> None:
# The request_logger None check was moved below, to just before request_logger is used.
# if self.request_logger is None:
# return
# If self.model_config.pooler_config is not None, this is an embedding task, not a generation task.
if self.model_config.pooler_config:
return
prompt, prompt_token_ids, prompt_embeds = None, None, None
if isinstance(inputs, str):
prompt = inputs
elif isinstance(inputs, list):
prompt_token_ids = inputs
else:
# Engine prompts are TypedDicts, so read keys with .get(); getattr() would always return None for them.
prompt = inputs.get('prompt') if isinstance(inputs, dict) else getattr(inputs, 'prompt', None)
prompt_token_ids = inputs.get('prompt_token_ids') if isinstance(inputs, dict) else getattr(inputs, 'prompt_token_ids', None)
# Penalty settings read from generation_config: if present, warn and override them.
if (getattr(params, "repetition_penalty", None) is not None
and abs(params.repetition_penalty - 1.0) >= _SAMPLING_EPS):
logger.warning(
"\033[93mWARNING \033[0m"
": Unsupported penalty for sampler. "
f"params.repetition_penalty: {params.repetition_penalty}. "
"Please set extra_body = {'repetition_penalty': 1.0} in the client.\n"
"Now overriding repetition_penalty to 1.0."
)
# params.presence_penalty = 0
# params.frequency_penalty = 0
params.repetition_penalty = 1.0
if hasattr(params, "min_p") and params.min_p is not None and params.min_p > _SAMPLING_EPS:
logger.warning(f"\033[93mWARNING \033[0m : Unsupported min_p {params.min_p} for sampler")
params.min_p = 0
if self.request_logger is None:
return
self.request_logger.log_inputs(
request_id,
prompt,
prompt_token_ids,
prompt_embeds,
params=params,
lora_request=lora_request,
)
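
Since _log_inputs above warns about and overrides non-neutral repetition_penalty and positive min_p, clients can avoid the warning by sending neutral values explicitly, as the message itself suggests. A minimal sketch (endpoint and model name are placeholders):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.completions.create(
    model="my-model",  # placeholder
    prompt="Hello",
    max_tokens=16,
    # Neutral values keep the sampler patch from emitting warnings.
    extra_body={"repetition_penalty": 1.0, "min_p": 0.0},
)
print(resp.choices[0].text)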

View File

@@ -0,0 +1,127 @@
import asyncio
import io
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Annotated, Optional, Union
import pybase64
import torch
from pydantic import Field
from vllm.config import ModelConfig
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AsyncMicrobatchTokenizer
class BaseRenderer(ABC):
"""
Base class for unified input processing and rendering.
The Renderer serves as a unified input processor that consolidates
tokenization, chat template formatting, and multimodal input handling
into a single component.
It converts high-level API requests (OpenAI-style JSON) into token IDs and
multimodal features ready for engine consumption.
Key responsibilities:
- Convert text prompts to token sequences with proper special tokens
- Apply chat templates and format conversations
- Handle multimodal inputs (images, audio, etc.) when applicable
- Manage prompt truncation and length validation
- Provide clean separation between API layer and engine core
"""
@classmethod
def load_prompt_embeds(
cls,
prompt_embeds: Union[bytes, list[bytes]],
deepstack_input_embeds: Optional[Union[bytes, str]] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=0)]] = None,
cache_salt: Optional[str] = None,
) -> list[EngineEmbedsPrompt]:
"""Load and validate base64-encoded embeddings into prompt objects."""
def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt:
tensor = torch.load(
io.BytesIO(pybase64.b64decode(embed, validate=True)),
weights_only=True,
map_location=torch.device("cpu"),
)
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
torch.float32,
torch.bfloat16,
torch.float16,
)
tensor = tensor.to_dense()
if tensor.dim() > 2:
tensor = tensor.squeeze(0)
assert tensor.dim() == 2
if truncate_prompt_tokens is not None:
tensor = tensor[-truncate_prompt_tokens:]
embeds_prompt = EngineEmbedsPrompt(prompt_embeds=tensor)
if cache_salt is not None:
embeds_prompt["cache_salt"] = cache_salt
if deepstack_input_embeds is not None:
all_tensor = []
from vllm.sequence import IntermediateTensors
tensor_dict = torch.load(
io.BytesIO(pybase64.b64decode(deepstack_input_embeds, validate=True))
)
for k in tensor_dict:
all_tensor.append(tensor_dict[k].unsqueeze(0))
all_tensor = torch.concatenate(all_tensor, 0)
embeds_prompt["deepstack_input_embeds"] = all_tensor #IntermediateTensors(tensors=tensor_dict)
return embeds_prompt
if isinstance(prompt_embeds, list):
return [_load_and_validate_embed(embed) for embed in prompt_embeds]
return [_load_and_validate_embed(prompt_embeds)]
class CompletionRenderer(BaseRenderer):
async def render_prompt_and_embeds(
self,
*,
prompt_or_prompts: Optional[Union[str, list[str], list[int],
list[list[int]]]] = None,
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
deepstack_input_embeds: Optional[Union[bytes, list[bytes]]] = None,
config: "RenderConfig",
) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
"""
Render text/token prompts and/or precomputed embedding prompts. At
least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
"""
truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
config.truncate_prompt_tokens, config.max_length)
if truncate_prompt_tokens == 0:
return []
rendered: list[Union[EngineTokensPrompt, EngineEmbedsPrompt]] = []
if prompt_embeds is not None:
rendered.extend(
self.load_prompt_embeds(prompt_embeds, deepstack_input_embeds, truncate_prompt_tokens,
config.cache_salt))
if prompt_or_prompts is None or prompt_or_prompts == "":
return rendered
token_prompts = await self.render_prompt(
prompt_or_prompts=prompt_or_prompts,
config=config,
)
rendered.extend(token_prompts)
return rendered
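
The decode path in load_prompt_embeds (pybase64.b64decode followed by torch.load) implies the matching client-side encoding. A minimal sketch of that encoding (tensor shapes and the deepstack layer keys are illustrative assumptions, not part of this patch):

import io

import pybase64
import torch

# Encode a (seq_len, hidden_size) embedding tensor the way
# load_prompt_embeds expects to decode it: torch.save -> base64.
prompt_embeds = torch.randn(128, 4096, dtype=torch.bfloat16)  # assumed shape
buf = io.BytesIO()
torch.save(prompt_embeds, buf)
prompt_embeds_b64 = pybase64.b64encode(buf.getvalue()).decode("ascii")

# Optional deepstack embeddings: a dict of per-layer tensors serialized the
# same way; the server unsqueezes and concatenates the values along dim 0.
deepstack = {f"layer_{i}": torch.randn(128, 4096) for i in range(3)}  # illustrative
buf2 = io.BytesIO()
torch.save(deepstack, buf2)
deepstack_b64 = pybase64.b64encode(buf2.getvalue()).decode("ascii")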

View File

View File

@@ -0,0 +1,20 @@
import asyncio
from typing import List
from vllm.v1.outputs import PoolerOutput, SamplerOutput
from vllm.sequence import ExecuteModelRequest
# class DistributedExecutorBase():
# """Abstract superclass of distributed executor implementations."""
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
# Start model execution loop running in the parallel workers
self.parallel_worker_tasks = asyncio.create_task(
self._start_worker_execution_loop())
await asyncio.sleep(0)
# Only the driver worker returns the sampling results.
await asyncio.sleep(0)
return await self._driver_execute_model_async(execute_model_req)

View File

Binary file not shown.

View File

@@ -0,0 +1,55 @@
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast
import torch
from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar
if TYPE_CHECKING:
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalInputs,
MultiModalUUIDDict)
from vllm.sequence import IntermediateTensors
class EmbedsPrompt(TypedDict):
"""Schema for a prompt provided via token embeddings."""
prompt_embeds: torch.Tensor
"""The embeddings of the prompt."""
deepstack_input_embeds: Optional[IntermediateTensors]
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
class EmbedsInputs(TypedDict):
"""Represents embeddings-based inputs."""
type: Literal["embeds"]
"""The type of inputs."""
prompt_embeds: torch.Tensor
"""The embeddings of the prompt."""
deepstack_input_embeds: torch.Tensor
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
def embeds_inputs(
prompt_embeds: torch.Tensor,
deepstack_input_embeds: torch.Tensor,
cache_salt: Optional[str] = None,
) -> EmbedsInputs:
"""Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional
values."""
inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds, deepstack_input_embeds=deepstack_input_embeds)
if cache_salt is not None:
inputs["cache_salt"] = cache_salt
return inputs
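
A small usage sketch for embeds_inputs above (shapes are illustrative assumptions; it assumes embeds_inputs from this module is in scope):

import torch

# cache_salt is only added to the dict when provided.
inputs = embeds_inputs(
    prompt_embeds=torch.zeros(16, 4096),              # (seq_len, hidden_size), assumed
    deepstack_input_embeds=torch.zeros(3, 16, 4096),  # (num_layers, seq_len, hidden), assumed
    cache_salt="demo-salt",
)
assert inputs["type"] == "embeds"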

View File

@@ -0,0 +1,54 @@
from collections.abc import Mapping
from typing import Any, Optional, Union, cast
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs, MultiModalUUIDDict)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .data import EmbedsInputs, EmbedsPrompt, embeds_inputs
logger = init_logger(__name__)
class InputPreprocessor:
def _process_embeds(
self,
parsed_content: EmbedsPrompt,
) -> EmbedsInputs:
if not self.model_config.enable_prompt_embeds:
raise ValueError("You must set `--enable-prompt-embeds` to input "
"`prompt_embeds`.")
prompt_embeds = parsed_content["prompt_embeds"]
deepstack_input_embeds = None
if 'deepstack_input_embeds' in parsed_content:
deepstack_input_embeds = parsed_content["deepstack_input_embeds"]
# prompt_embeds must be (seq_len, hidden_size), but if the user
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
# we can unambiguously process the intent by squeezing the batch
# dimension.
if prompt_embeds.ndim == 3:
prompt_embeds = prompt_embeds.squeeze(dim=0)
if prompt_embeds.ndim != 2:
raise ValueError(
"prompt_embeds must be of shape (seq_len, hidden_size).")
# Tensors must be on CPU for serialization between processes
# in the MsgpackEncoder. Casting to CPU here ensures that there is no
# hidden device transfer in the critical path of generation.
prompt_embeds = prompt_embeds.cpu()
return embeds_inputs(prompt_embeds=prompt_embeds,
deepstack_input_embeds=deepstack_input_embeds,
cache_salt=parsed_content.get("cache_salt"))
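
_process_embeds accepts either a (seq_len, hidden_size) tensor or a singleton-batched (1, seq_len, hidden_size) tensor and rejects anything else. A standalone sketch that mirrors just that normalization:

import torch

def normalize_prompt_embeds(prompt_embeds: torch.Tensor) -> torch.Tensor:
    # Mirror of the squeeze/validation above: a leading batch dim of 1 is
    # removed, and the result must be 2-D (seq_len, hidden_size) on CPU.
    if prompt_embeds.ndim == 3:
        prompt_embeds = prompt_embeds.squeeze(dim=0)
    if prompt_embeds.ndim != 2:
        raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")
    return prompt_embeds.cpu()

print(normalize_prompt_embeds(torch.zeros(1, 8, 4096)).shape)  # torch.Size([8, 4096])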

View File

@@ -0,0 +1,38 @@
import torch.nn as nn
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
class CustomOp(nn.Module):
def forward_vacc(self, *args, **kwargs):
raise NotImplementedError
def dispatch_forward(self):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
enabled = self.enabled()
logger.debug("custom op %s %s", self.__class__.name,
"enabled" if enabled else "disabled")
if not enabled:
return self.forward_native
if current_platform.is_rocm():
return self.forward_hip
elif current_platform.is_cpu():
return self.forward_cpu
elif current_platform.is_hpu():
return self.forward_hpu
elif current_platform.is_tpu():
return self.forward_tpu
elif current_platform.is_xpu():
return self.forward_xpu
elif current_platform.is_vacc():
return self.forward_vacc
else:
return self.forward_cuda
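
With dispatch_forward routing to forward_vacc on the vacc platform, a backend-specific op would typically override forward_vacc next to forward_native. A hypothetical sketch (the op and its math are illustrative; it assumes this CustomOp is merged into vLLM's CustomOp, which supplies enabled() and the dispatch plumbing):

import torch

class ScaledAdd(CustomOp):  # hypothetical op, for illustration only
    """Adds two tensors, scaling the second operand."""

    def __init__(self, scale: float = 1.0):
        super().__init__()
        self.scale = scale

    def forward_native(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Reference implementation used when the custom op is disabled.
        return x + self.scale * y

    def forward_vacc(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # On the vacc backend this is where a device-specific kernel would be
        # called; the sketch just reuses the reference math.
        return self.forward_native(x, y)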

Some files were not shown because too many files have changed in this diff.