init
This commit is contained in:
66
vllm_vacc/__init__.py
Normal file
66
vllm_vacc/__init__.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
def generate_bootinfo():
|
||||
from pathlib import Path
|
||||
import datetime
|
||||
import os
|
||||
workspace_path = Path.cwd()
|
||||
bootinfo_time = str(datetime.datetime.now())
|
||||
bootinfo_tag_vccl = "[vccl]"
|
||||
bootinfo_seq = "@@@"
|
||||
|
||||
bootinfo_config = f'{workspace_path}/.bootinfos'
|
||||
bootinfo_inited = os.path.exists(bootinfo_config)
|
||||
|
||||
bootinfo_start_up_times = 0
|
||||
if bootinfo_inited:
|
||||
try:
|
||||
with open(bootinfo_config) as w:
|
||||
current_bootinfos = w.readline()
|
||||
# print("current_bootinfos:",current_bootinfos)
|
||||
bootinfo_start_up_times = int(current_bootinfos.split('cycle')[1].strip())
|
||||
except Exception as e:
|
||||
print("[WARN] read bootinfo fail, caused by ", e)
|
||||
# limit max run times
|
||||
if bootinfo_start_up_times > 10000000:
|
||||
bootinfo_start_up_times = 0
|
||||
|
||||
current_bootinfos = f'{bootinfo_tag_vccl}{bootinfo_seq}{bootinfo_time} cycle {bootinfo_start_up_times + 1}'
|
||||
try:
|
||||
with open(bootinfo_config, 'w') as w:
|
||||
w.write(current_bootinfos)
|
||||
except Exception as e:
|
||||
print("[WARN] write bootinfo fail, caused by ", e)
|
||||
|
||||
def exec_patching():
|
||||
import sys
|
||||
from vllm_vacc.vllm_patch import VllmPatchManager, patch_vllm, patch_vllm_v1, patch_torch, regist_mock_module
|
||||
import sys
|
||||
|
||||
vpm = VllmPatchManager
|
||||
|
||||
if vpm.patched:
|
||||
return
|
||||
|
||||
regist_mock_module()
|
||||
patch_torch()
|
||||
patch_vllm(vpm)
|
||||
patch_vllm_v1(vpm)
|
||||
|
||||
vpm.apply_patches()
|
||||
generate_bootinfo()
|
||||
|
||||
if "vllm.executor.executor_base" in sys.modules:
|
||||
del sys.modules["vllm.executor.executor_base"]
|
||||
|
||||
def register():
|
||||
return "vllm_vacc.platform.VaccPlatform"
|
||||
|
||||
def register_model():
|
||||
exec_patching()
|
||||
|
||||
# from .models import register_model
|
||||
# register_model()
|
||||
return
|
||||
|
||||
# exec_patching()
|
||||
BIN
vllm_vacc/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm_vacc/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/__pycache__/patch_util.cpython-312.pyc
Normal file
BIN
vllm_vacc/__pycache__/patch_util.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/__pycache__/platform.cpython-312.pyc
Normal file
BIN
vllm_vacc/__pycache__/platform.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/__pycache__/vllm_patch.cpython-312.pyc
Normal file
BIN
vllm_vacc/__pycache__/vllm_patch.cpython-312.pyc
Normal file
Binary file not shown.
299
vllm_vacc/patch_util.py
Normal file
299
vllm_vacc/patch_util.py
Normal file
@@ -0,0 +1,299 @@
|
||||
import importlib
|
||||
import sys
|
||||
import pkgutil
|
||||
import types
|
||||
from typing import List
|
||||
|
||||
PATCH_ROOT = "vllm_vacc."
|
||||
|
||||
|
||||
def get_func_name(func):
|
||||
if isinstance(func, str):
|
||||
return func
|
||||
return ".".join((func.__module__, func.__qualname__))
|
||||
|
||||
|
||||
def dummy_function_wrapper(func_name):
|
||||
def dummy_function(*args, **kwargs):
|
||||
raise RuntimeError(f"function {func_name} no exist")
|
||||
|
||||
return dummy_function
|
||||
|
||||
|
||||
def dummy_jit(fn):
|
||||
def wrapper(*args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class Patch:
|
||||
def __init__(self, orig_func_name, new_func, create_dummy):
|
||||
split_name = orig_func_name.rsplit(".", 1)
|
||||
if len(split_name) == 1:
|
||||
self.orig_module_name, self.orig_func_name = orig_func_name, None
|
||||
else:
|
||||
self.orig_module_name, self.orig_func_name = split_name
|
||||
self.orig_module = None
|
||||
self.orig_func = None
|
||||
|
||||
self.patch_func = None
|
||||
self.wrappers = []
|
||||
if new_func is None:
|
||||
new_func = dummy_function_wrapper(orig_func_name)
|
||||
self.set_patch_func(new_func)
|
||||
self.is_applied = False
|
||||
self.create_dummy = create_dummy
|
||||
|
||||
@property
|
||||
def orig_func_id(self):
|
||||
return id(self.orig_func)
|
||||
|
||||
@property
|
||||
def patch_func_id(self):
|
||||
return id(self.patch_func)
|
||||
|
||||
def set_patch_func(self, new_func, force_patch=False):
|
||||
if hasattr(new_func, "__name__") and new_func.__name__.endswith(
|
||||
("wrapper", "decorator")
|
||||
):
|
||||
self.wrappers.append(new_func)
|
||||
else:
|
||||
if self.patch_func and not force_patch:
|
||||
raise RuntimeError(
|
||||
f"The patch of '{self.orig_func_name}' ('{self.patch_func}') exist!"
|
||||
)
|
||||
self.patch_func = new_func
|
||||
self.is_applied = False
|
||||
|
||||
def apply_patch(self):
|
||||
if self.is_applied:
|
||||
return
|
||||
|
||||
self.orig_module, self.orig_func = Patch.parse_path(
|
||||
self.orig_module_name, self.orig_func_name, self.create_dummy
|
||||
)
|
||||
if self.patch_func is None:
|
||||
self.patch_func = self.orig_func
|
||||
|
||||
for wrapper in self.wrappers:
|
||||
self.patch_func = wrapper(self.patch_func)
|
||||
|
||||
if self.orig_func_name is not None:
|
||||
setattr(self.orig_module, self.orig_func_name, self.patch_func)
|
||||
for key, value in sys.modules.copy().items():
|
||||
# 遍历 pip 所有库, 然后 setattr, 有些库不匹配 可能会有问题, 这里是否可以优化 只遍历vllm相关
|
||||
try:
|
||||
if (
|
||||
self.orig_func_name is not None
|
||||
and hasattr(value, self.orig_func_name)
|
||||
and id(getattr(value, self.orig_func_name)) == self.orig_func_id
|
||||
):
|
||||
setattr(value, self.orig_func_name, self.patch_func)
|
||||
except:
|
||||
continue
|
||||
|
||||
self.is_applied = True
|
||||
|
||||
@staticmethod
|
||||
def parse_function(function_path: str, create_dummy):
|
||||
split_name = function_path.rsplit(".", 1)
|
||||
if len(split_name) == 1:
|
||||
orig_module_name, orig_func_name = function_path, None
|
||||
else:
|
||||
orig_module_name, orig_func_name = split_name
|
||||
return Patch.parse_path(orig_module_name, orig_func_name, create_dummy)[1]
|
||||
|
||||
@staticmethod
|
||||
def parse_path(module_path, function_name, create_dummy):
|
||||
from importlib.machinery import ModuleSpec
|
||||
|
||||
modules = module_path.split(".")
|
||||
for i in range(1, len(modules) + 1):
|
||||
parent = ".".join(modules[: i - 1])
|
||||
path = ".".join(modules[:i])
|
||||
try:
|
||||
importlib.import_module(path)
|
||||
except ModuleNotFoundError as e:
|
||||
if not parent or not hasattr(
|
||||
importlib.import_module(parent), modules[i - 1]
|
||||
):
|
||||
if not create_dummy:
|
||||
raise ModuleNotFoundError(e) from e
|
||||
sys.modules[path] = types.ModuleType(path)
|
||||
sys.modules[path].__file__ = "patch_tools.dummy_module.py"
|
||||
sys.modules[path].__spec__ = ModuleSpec(path, None)
|
||||
if parent:
|
||||
setattr(
|
||||
importlib.import_module(parent),
|
||||
modules[i - 1],
|
||||
sys.modules[path],
|
||||
)
|
||||
else:
|
||||
module = getattr(importlib.import_module(parent), modules[i - 1])
|
||||
if hasattr(module, function_name):
|
||||
return module, getattr(module, function_name)
|
||||
elif create_dummy:
|
||||
return module, dummy_function_wrapper(function_name)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Function '{function_name}' of '{module}' does not exist."
|
||||
) from e
|
||||
|
||||
if function_name is not None and not hasattr(
|
||||
sys.modules[module_path], function_name
|
||||
):
|
||||
setattr(sys.modules[module_path], function_name, None)
|
||||
return sys.modules[module_path], (
|
||||
getattr(sys.modules[module_path], function_name)
|
||||
if function_name is not None
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
class PatchManager:
|
||||
patches_info: dict = {}
|
||||
patched: bool = False
|
||||
|
||||
@classmethod
|
||||
def get_patch_info(cls):
|
||||
return cls.patches_info
|
||||
|
||||
@classmethod
|
||||
def register_patch(
|
||||
cls,
|
||||
orig_func_name,
|
||||
new_func=None,
|
||||
force_patch=False,
|
||||
create_dummy=False,
|
||||
allow_create=False,
|
||||
):
|
||||
"""
|
||||
if new_func is written via @wraps, its name must be ended with `wrapper` or `decorator`,
|
||||
also if it ends with `wrapper` or `decorator`, it must be written via `@wraps`
|
||||
"""
|
||||
if not cls._path_valid(orig_func_name):
|
||||
if not allow_create:
|
||||
raise ValueError(
|
||||
f"Module/function path '{orig_func_name}' does not exist, and allow_create=False."
|
||||
)
|
||||
|
||||
# if not create_dummy and not cls._path_valid(orig_func_name):
|
||||
# print(f"WARNING: path '{orig_func_name}' is not valid, skipped.")
|
||||
# return
|
||||
|
||||
patch_info = cls.get_patch_info()
|
||||
if orig_func_name not in patch_info:
|
||||
patch_info[orig_func_name] = Patch(orig_func_name, new_func, create_dummy)
|
||||
else:
|
||||
patch_info[orig_func_name].set_patch_func(new_func, force_patch)
|
||||
|
||||
@classmethod
|
||||
def _is_module(self, module_path):
|
||||
try:
|
||||
importlib.import_module(module_path)
|
||||
return True
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def recursive_register_module(cls, module_path, allow_create=False):
|
||||
# replace whole submodule in a module
|
||||
assert cls._is_module(
|
||||
PATCH_ROOT + module_path
|
||||
), f'"{PATCH_ROOT}{module_path}" is not a valid module path. Only use this function to register module patches (not functions or classes).'
|
||||
|
||||
all_submodules = cls.enumerate_submodules(PATCH_ROOT + module_path)
|
||||
all_submodules = [submodule[len(PATCH_ROOT) :] for submodule in all_submodules]
|
||||
all_submodules = [module_path] + all_submodules
|
||||
|
||||
for module in all_submodules:
|
||||
try:
|
||||
importlib.import_module(module)
|
||||
except ModuleNotFoundError as e:
|
||||
pass
|
||||
sys.modules[module] = importlib.import_module(PATCH_ROOT + module)
|
||||
|
||||
@classmethod
|
||||
def batch_recursive_register_module(cls, module_paths, allow_create=False):
|
||||
for module_path in module_paths:
|
||||
cls.recursive_register_module(module_path, allow_create=allow_create)
|
||||
|
||||
@classmethod
|
||||
def enumerate_submodules(cls, module_name):
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
except ImportError as e:
|
||||
print(f"Error importing module {module_name}: {e}")
|
||||
return []
|
||||
|
||||
submodules = []
|
||||
for loader, submodule_name, is_pkg in pkgutil.walk_packages(
|
||||
module.__path__, module.__name__ + "."
|
||||
):
|
||||
submodules.append(submodule_name)
|
||||
|
||||
return submodules
|
||||
|
||||
@classmethod
|
||||
def batch_register_patch(
|
||||
cls,
|
||||
orig_func_names: List,
|
||||
force_patch=False,
|
||||
create_dummy=False,
|
||||
allow_create=False,
|
||||
):
|
||||
"""
|
||||
This function assumes all new_func are organized in same path like orig_func_names except prefixed with 'vastext.'
|
||||
"""
|
||||
for orig_func_name in orig_func_names:
|
||||
if not cls._path_valid(orig_func_name) and not allow_create:
|
||||
print(f"WARNING: path '{orig_func_name}' is not valid, skipped.")
|
||||
continue
|
||||
|
||||
wrapper_name = orig_func_name
|
||||
if cls._path_valid(PATCH_ROOT + wrapper_name + "_wrapper"):
|
||||
wrapper_name = wrapper_name + "_wrapper"
|
||||
assert cls._path_valid(
|
||||
PATCH_ROOT + wrapper_name
|
||||
), f"'{PATCH_ROOT}{wrapper_name}' or '{PATCH_ROOT}{wrapper_name}_wrapper' must be a valid module/function path. Try import {PATCH_ROOT}{wrapper_name} to see if other errors exist."
|
||||
new_func = Patch.parse_function(
|
||||
PATCH_ROOT + wrapper_name, create_dummy=False
|
||||
)
|
||||
if new_func is None:
|
||||
new_func = importlib.import_module(PATCH_ROOT + wrapper_name)
|
||||
# print(f">>> Register patch or function: '{orig_func_name}' -> '{new_func}'")
|
||||
cls.register_patch(
|
||||
orig_func_name,
|
||||
new_func,
|
||||
force_patch=force_patch,
|
||||
create_dummy=create_dummy,
|
||||
allow_create=allow_create,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _path_valid(cls, path):
|
||||
components = path.split(".")
|
||||
for i in range(len(components), 0, -1):
|
||||
module_name = ".".join(components[:i])
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
break
|
||||
except ImportError:
|
||||
continue
|
||||
else:
|
||||
return False
|
||||
|
||||
for component in components[i:]:
|
||||
if hasattr(module, component):
|
||||
module = getattr(module, component)
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def apply_patches(cls):
|
||||
patch_info = cls.get_patch_info()
|
||||
for patch in patch_info.values():
|
||||
patch.apply_patch()
|
||||
cls.patched = True
|
||||
163
vllm_vacc/platform.py
Normal file
163
vllm_vacc/platform.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
# from .interface import Platform, PlatformEnum, _Backend
|
||||
from vllm.platforms.interface import Platform, PlatformEnum, _Backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig,ModelConfig
|
||||
else:
|
||||
VllmConfig = None
|
||||
ModelConfig = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class VaccPlatform(Platform):
|
||||
try:
|
||||
import torch_vacc
|
||||
is_vacc = True
|
||||
except Exception as e:
|
||||
assert False, f"error import torch_vacc: {e}"
|
||||
_enum = PlatformEnum.OOT
|
||||
device_name: str = "vacc"
|
||||
device_type: str = "vacc"
|
||||
dispatch_key: str = "PrivateUse1"
|
||||
ray_device_key: str = "GPU"
|
||||
device_control_env_var: str = "VACC_VISIBLE_MODULES"
|
||||
simple_compile_backend: str = "eager" # Disable torch.compile()
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
|
||||
dtype: torch.dtype, kv_cache_dtype: Optional[str],
|
||||
block_size: int, use_v1: bool,
|
||||
use_mla: bool, has_sink: bool, use_sparse: bool) -> str:
|
||||
if use_mla:
|
||||
logger.info("Using VACCMLA backend.")
|
||||
if use_v1:
|
||||
return "vllm_vacc.vllm.v1.attention.backends.vacc_mla.VACCMLABackend"
|
||||
return "vllm_vacc.vllm.attention.backends.vacc_mla.VACCMLABackend"
|
||||
if use_v1:
|
||||
return "vllm_vacc.vllm.v1.attention.backends.vacc_attn.VACCAttentionBackend"
|
||||
else:
|
||||
logger.info("Using VACCAttention backend.")
|
||||
return "vllm_vacc.vllm.attention.backends.vacc_attn.VACCAttentionBackend"
|
||||
|
||||
@classmethod
|
||||
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def inference_mode():
|
||||
return torch.no_grad()
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
import vllm.envs as envs
|
||||
|
||||
if vllm_config.kv_transfer_config:
|
||||
raise NotImplementedError("kv-transfer-config is not implemented for VACC")
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
if ((scheduler_config.chunked_prefill_enabled
|
||||
or cache_config.enable_prefix_caching)
|
||||
and cache_config.cache_dtype != "auto"):
|
||||
raise RuntimeError("Chunked-prefill and prefix-cache on the Vacc "
|
||||
"backend is not compatible with FP8 KV cache.")
|
||||
|
||||
|
||||
# scheduling_polity = scheduler_config.policy
|
||||
# model_config = vllm_config.model_config
|
||||
# use_async_output_proc = model_config.use_async_output_proc
|
||||
# if scheduling_polity == "priority" and use_async_output_proc: # probably a bug
|
||||
# logger.warning("WARNING scheduling_polity priority is not fully supported for VACC, "
|
||||
# "use fcfs instead automatically")
|
||||
# vllm_config.scheduler_config.scheduling_polity = "fcfs"
|
||||
|
||||
# if vllm_config.speculative_config is not None:
|
||||
# raise NotImplementedError(
|
||||
# "Speculative decoding is not implemented for VACC")
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
if parallel_config.worker_cls == "auto":
|
||||
if vllm_config.speculative_config:
|
||||
if envs.VLLM_USE_V1:
|
||||
parallel_config.worker_cls = "vllm_vacc.vllm.v1.worker.vacc_worker.VACCWorker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.spec_decode.spec_decode_worker.create_spec_worker"
|
||||
parallel_config.sd_worker_cls = "vllm_vacc.vllm.worker.vacc_worker.VACCWorker"
|
||||
else:
|
||||
if envs.VLLM_USE_V1:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm_vacc.vllm.v1.worker.vacc_worker.VACCWorker"
|
||||
print('v1 VACCWorker')
|
||||
else:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm_vacc.vllm.worker.vacc_worker.VACCWorker"
|
||||
|
||||
|
||||
# NOTE(kzawora): default block size for Gaudi should be 128
|
||||
# smaller sizes still work, but very inefficiently
|
||||
cache_config = vllm_config.cache_config
|
||||
if cache_config and cache_config.gpu_memory_utilization:
|
||||
logger.warning("WARNING gpu_memory_utilization is not supported on VACC")
|
||||
|
||||
# if cache_config and cache_config.enable_prefix_caching:
|
||||
# raise NotImplementedError("Prefix-caching is not implemented for VACC")
|
||||
|
||||
if cache_config and cache_config.block_size is None:
|
||||
cache_config.block_size = 16
|
||||
if (parallel_config.distributed_executor_backend == 'mp'
|
||||
and envs.VLLM_WORKER_MULTIPROC_METHOD == 'fork'):
|
||||
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD",
|
||||
None) is not None:
|
||||
logger.warning("On VACC, VLLM_WORKER_MULTIPROC_METHOD=fork "
|
||||
"might cause application hangs on exit. Using "
|
||||
"VLLM_WORKER_MULTIPROC_METHOD=fork anyway, "
|
||||
"as it was explicitly requested.")
|
||||
else:
|
||||
logger.warning(
|
||||
"On VACC, VLLM_WORKER_MULTIPROC_METHOD=fork "
|
||||
"might cause application hangs on exit. Setting "
|
||||
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
|
||||
"To override that behavior, please set "
|
||||
"VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.")
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
@classmethod
|
||||
def is_pin_memory_available(cls):
|
||||
logger.warning("Pin memory is not supported on VACC.")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_punica_wrapper(cls) -> str:
|
||||
return "vllm_vacc.vllm.lora.punica_wrapper.punica_vacc.PunicaWrapperVACC"
|
||||
|
||||
@classmethod
|
||||
def get_current_memory_usage(cls,
|
||||
device: Optional[torch.types.Device] = None
|
||||
) -> float:
|
||||
torch.vacc.reset_peak_memory_stats(device)
|
||||
return torch.vacc.max_memory_allocated(device)
|
||||
|
||||
@classmethod
|
||||
def use_all_gather(cls) -> bool:
|
||||
"""
|
||||
Whether to use allgather in LogitsProcessor to gather the logits.
|
||||
"""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_v1(cls, model_config: ModelConfig) -> bool:
|
||||
"""Returns whether the current platform can support v1 for the supplied
|
||||
model configuration.
|
||||
"""
|
||||
# return False # or export VLLM_USE_V1=0 to use v0
|
||||
if os.getenv("VLLM_USE_V1", 1) == '0':
|
||||
return False
|
||||
return True
|
||||
0
vllm_vacc/vllm/__init__.py
Normal file
0
vllm_vacc/vllm/__init__.py
Normal file
BIN
vllm_vacc/vllm/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/__pycache__/_custom_ops.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/__pycache__/_custom_ops.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/__pycache__/config.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/__pycache__/config.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/__pycache__/config_manager.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/__pycache__/config_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/__pycache__/sequence.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/__pycache__/sequence.cpython-312.pyc
Normal file
Binary file not shown.
95
vllm_vacc/vllm/_custom_ops.py
Normal file
95
vllm_vacc/vllm/_custom_ops.py
Normal file
@@ -0,0 +1,95 @@
|
||||
|
||||
import contextlib
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch
|
||||
import torch.library
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.scalar_type import ScalarType
|
||||
|
||||
def cutlass_scaled_mm_vacc(a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""
|
||||
`cutlass_scaled_mm` implements a fused version of
|
||||
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
|
||||
where scale_a * a and scale_b * b are implemented using numpy-style
|
||||
broadcasting.
|
||||
|
||||
In order to support blockwise scaling like found in DeepSeek V3 we also
|
||||
support extended "group" broadcast rules. We extend the numpy-style
|
||||
broadcasting rules with the following rule:
|
||||
"if the extent of a dimension in the source shape is between 1 and
|
||||
corresponding extent in the target shape we repeat each element along
|
||||
that dimension src_shape[dim] // target_shape[dim] times consecutively"
|
||||
example if we have:
|
||||
a = [[1, 2], and target_shape = (2, 4)
|
||||
[3, 4]]
|
||||
then we would expand a to:
|
||||
a = [[1, 1, 2, 2],
|
||||
[3, 3, 4, 4]]
|
||||
currently we only support the case:
|
||||
scale_a.shape * [1, 128] == a.shape
|
||||
scale_b.shape * [128, 128] == b.shape
|
||||
"""
|
||||
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
|
||||
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
|
||||
assert bias is None or bias.shape[0] == b.shape[
|
||||
1] and bias.dtype == out_dtype
|
||||
|
||||
m = a.shape[0]
|
||||
n = b.shape[1]
|
||||
|
||||
if current_platform.is_rocm():
|
||||
triton_scaled_mm_module = importlib.import_module(
|
||||
"vllm.model_executor.layers.quantization.compressed_tensors."
|
||||
"triton_scaled_mm")
|
||||
triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
|
||||
return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
|
||||
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
|
||||
# print('a',a.shape,a.dtype) # torch.Size([8192, 3584]) torch.float8_e4m3fn
|
||||
# print('scale_a',scale_a.shape) #torch.Size([8192, 56])
|
||||
# print('b',b.shape,b.dtype) # torch.Size([3584, 1536]) torch.float8_e4m3fn
|
||||
# print('scale_b',scale_b.shape) #torch.Size([56, 12])
|
||||
|
||||
use_a32_w32 = True #反量化到fp32 计算 matmul
|
||||
|
||||
if use_a32_w32 or (b.shape[1]//scale_b.shape[1] != 128 or
|
||||
a.shape[1]//scale_a.shape[1] != 128 or
|
||||
b.shape[0]//scale_b.shape[0] != 128):
|
||||
# cutlass_scaled_mm 不支持非128的 quant block
|
||||
a1 = a.to(torch.float32).reshape(a.shape[0], scale_a.shape[1], -1)
|
||||
scale_a = scale_a.reshape(scale_a.shape[0], scale_a.shape[1], 1).to(torch.float32)
|
||||
a = (a1*scale_a).reshape(a.shape).contiguous()
|
||||
|
||||
b1 = b.to(torch.float32).reshape(scale_b.shape[0], b.shape[0]//scale_b.shape[0], scale_b.shape[1], b.shape[1]//scale_b.shape[1])
|
||||
scale_b = scale_b.reshape(scale_b.shape[0], 1, scale_b.shape[1], 1).to(torch.float32)
|
||||
b = (b1*scale_b).reshape(b.shape).contiguous()
|
||||
|
||||
out = a@b
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.to(out_dtype)
|
||||
|
||||
torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def concat_and_cache_mla(
|
||||
kv_c: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
scale: torch.Tensor,
|
||||
) -> None:
|
||||
torch.vacc.concat_and_cache_attention(
|
||||
kv_c, k_pe, kv_cache, slot_mapping)
|
||||
0
vllm_vacc/vllm/attention/__init__.py
Normal file
0
vllm_vacc/vllm/attention/__init__.py
Normal file
BIN
vllm_vacc/vllm/attention/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/attention/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
0
vllm_vacc/vllm/attention/backends/__init__.py
Normal file
0
vllm_vacc/vllm/attention/backends/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
0
vllm_vacc/vllm/attention/backends/mla/__init__.py
Normal file
0
vllm_vacc/vllm/attention/backends/mla/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
202
vllm_vacc/vllm/attention/backends/mla/common.py
Normal file
202
vllm_vacc/vllm/attention/backends/mla/common.py
Normal file
@@ -0,0 +1,202 @@
|
||||
|
||||
import functools
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
|
||||
Type, TypeVar)
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import AttentionLayer
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionState, MLAAttentionImpl)
|
||||
from vllm.attention.backends.mla.common import MLACommonMetadata,triton_attention
|
||||
|
||||
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
try:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
is_vllm_fa = True
|
||||
except ImportError:
|
||||
is_vllm_fa = False
|
||||
try:
|
||||
# For rocm use upstream flash attention
|
||||
from vllm.attention.backends.flash_attn import flash_attn_varlen_func
|
||||
except ImportError:
|
||||
flash_attn_varlen_func = None
|
||||
|
||||
T = TypeVar("T", bound="MLACommonMetadata")
|
||||
|
||||
|
||||
class MLACommonImpl():
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
# blocksparse_params: Optional[Dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
q_lora_rank: Optional[int],
|
||||
kv_lora_rank: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
rotary_emb: RotaryEmbedding,
|
||||
# q_proj should be q_b_proj if q_lora_rank is not None, but from an
|
||||
# attention backend perspective we rely on the layer to pass in the
|
||||
# correct matrix
|
||||
q_proj: ColumnParallelLinear,
|
||||
kv_b_proj: ColumnParallelLinear,
|
||||
o_proj: RowParallelLinear,
|
||||
positions: torch.Tensor = None,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.qk_head_dim = qk_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
|
||||
self.rotary_emb = rotary_emb
|
||||
self.use_yarn_rope = isinstance(rotary_emb,
|
||||
DeepseekScalingRotaryEmbedding)
|
||||
self.q_proj = q_proj
|
||||
self.kv_b_proj = kv_b_proj
|
||||
self.o_proj = o_proj
|
||||
self.positions = positions
|
||||
|
||||
self.triton_fa_func = triton_attention
|
||||
# Handle the differences between the flash_attn_varlen from flash_attn
|
||||
# and the one from vllm_flash_attn. The former is used on RoCM and the
|
||||
# latter has an additional parameter to control FA2 vs FA3
|
||||
self.flash_attn_varlen_func = flash_attn_varlen_func
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
if self.vllm_flash_attn_version is not None:
|
||||
self.flash_attn_varlen_func = \
|
||||
functools.partial(flash_attn_varlen_func,
|
||||
fa_version=self.vllm_flash_attn_version)
|
||||
|
||||
# For MLA the v head dim is smaller than qk head dim so we pad out
|
||||
# v with 0s to match the qk head dim for attention backends that do
|
||||
# not support different headdims
|
||||
# We don't need to pad V if we are on a hopper system with FA3
|
||||
self._pad_v = self.vllm_flash_attn_version is None or not (
|
||||
self.vllm_flash_attn_version == 3
|
||||
and current_platform.get_device_capability()[0] == 9)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
hidden_states_or_q_c: torch.Tensor, # query in unified attn
|
||||
k_c_normed: torch.Tensor, # key in unified attn
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if output is not None:
|
||||
raise NotImplementedError(
|
||||
"output is not yet supported for MLAImplBase")
|
||||
|
||||
# if attn_metadata.is_profile_run and \
|
||||
# attn_metadata.context_chunk_workspace is not None:
|
||||
# # During the profile run try to simulate to worse case output size
|
||||
# # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
|
||||
# # since this can be large
|
||||
# _ = torch.empty(
|
||||
# (attn_metadata.context_chunk_workspace.shape[0],
|
||||
# self.num_heads, self.qk_nope_head_dim + self.v_head_dim),
|
||||
# device=k_c_normed.device,
|
||||
# dtype=k_c_normed.dtype,
|
||||
# )
|
||||
|
||||
has_decode = attn_metadata.decode_metadata is not None
|
||||
has_prefill = attn_metadata.prefill_metadata is not None
|
||||
|
||||
# Restore head dim (for rotary embedding)
|
||||
k_pe = k_pe.unsqueeze(1)
|
||||
# assert hasattr(attn_metadata, "input_positions")
|
||||
if self.positions is not None:
|
||||
positions = self.positions
|
||||
elif hasattr(attn_metadata, "input_positions"):
|
||||
positions = attn_metadata.input_positions
|
||||
else:
|
||||
raise ValueError('no positions')
|
||||
|
||||
|
||||
num_prefill_tokens: int = attn_metadata.num_prefill_tokens
|
||||
|
||||
decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:]
|
||||
decode_k_pe = k_pe[num_prefill_tokens:]
|
||||
decode_input_positions = \
|
||||
positions[num_prefill_tokens:]
|
||||
|
||||
prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens]
|
||||
prefill_k_pe = k_pe[:num_prefill_tokens]
|
||||
prefill_input_positions = \
|
||||
positions[:num_prefill_tokens]
|
||||
prefill_k_c_normed = k_c_normed[:num_prefill_tokens]
|
||||
|
||||
if has_decode:
|
||||
decode_ql_nope, decode_q_pe = \
|
||||
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
||||
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
||||
decode_input_positions, decode_q_pe, decode_k_pe)
|
||||
|
||||
if has_prefill:
|
||||
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
||||
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
||||
prefill_input_positions, prefill_q_pe, prefill_k_pe)
|
||||
|
||||
# write the latent and rope to kv cache
|
||||
if kv_cache.numel() > 0:
|
||||
ops.concat_and_cache_mla(
|
||||
k_c_normed,
|
||||
k_pe.squeeze(1),
|
||||
kv_cache,
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
scale=layer._k_scale,
|
||||
)
|
||||
|
||||
# output = torch.empty(attn_metadata.num_prefill_tokens +
|
||||
# attn_metadata.num_decode_tokens,
|
||||
# self.o_proj.output_size,
|
||||
# device=hidden_states_or_q_c.device,
|
||||
# dtype=hidden_states_or_q_c.dtype)
|
||||
if has_prefill:
|
||||
return self._forward_prefill(
|
||||
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
|
||||
attn_metadata)
|
||||
|
||||
if has_decode:
|
||||
return self._forward_decode(
|
||||
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
|
||||
|
||||
assert False, "mla forward need prefill or decode function"
|
||||
return None
|
||||
390
vllm_vacc/vllm/attention/backends/mla/utils.py
Normal file
390
vllm_vacc/vllm/attention/backends/mla/utils.py
Normal file
@@ -0,0 +1,390 @@
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Generic, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl, T)
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsW8A8Fp8)
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
||||
# from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
# apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
scaled_dequantize, scaled_quantize)
|
||||
import os
|
||||
|
||||
W_Q_W_QR_WUV_WUK_USE_FP8 = True
|
||||
|
||||
class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
|
||||
def is_layer_fp8(layer: LinearBase) -> bool:
|
||||
return isinstance(layer.quant_method, Fp8LinearMethod) or\
|
||||
(isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
|
||||
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
|
||||
|
||||
def quantization_scheme_supported(layer: LinearBase) -> bool:
|
||||
return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
|
||||
is_layer_fp8(layer)
|
||||
|
||||
# TODO(lucas) This is very gross, we need a more wide scale refactor of
|
||||
# all the FP8 code with a more standard way of
|
||||
# defining schemes/group-shapes, we should also potentially force
|
||||
# quant_methods to support a decompress function
|
||||
#
|
||||
# returns input_group_shape, weight_group_shape
|
||||
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
|
||||
Tuple[Tuple[int, int], Tuple[int, int]]:
|
||||
if isinstance(layer.quant_method, Fp8LinearMethod):
|
||||
if layer.quant_method.block_quant is not None:
|
||||
weight_block_size = \
|
||||
layer.quant_method.quant_config.weight_block_size
|
||||
# per-token-group (1, X), block-quantized (X, Y)
|
||||
return (1, weight_block_size[-1]), weight_block_size
|
||||
else:
|
||||
return (-1, -1), (-1, -1) # per-tensor, per-tensor
|
||||
elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
|
||||
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# this is hacky but we always assume the for
|
||||
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
|
||||
# we ignore if it is static-per-tensor since we are going to
|
||||
# requantize after later anyways
|
||||
strategy = layer.scheme.strategy
|
||||
if strategy == QuantizationStrategy.TENSOR:
|
||||
return (1, -1), (-1, -1) # per-token, per-tensor
|
||||
elif strategy == QuantizationStrategy.CHANNEL:
|
||||
return (1, -1), (-1, 1) # per-token, per-channel
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"QuantizationStrategy.{strategy} is not supported for "
|
||||
"fp8 MLA, please run with VLLM_MLA_DISABLE=1")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Can't determine scale group shapes for "
|
||||
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
|
||||
)
|
||||
|
||||
def get_scales(layer: LinearBase) -> torch.Tensor:
|
||||
if hasattr(layer, "weight_scale_inv"):
|
||||
return layer.weight_scale_inv
|
||||
return layer.weight_scale
|
||||
|
||||
def get_fp8_layer_weight(layer: LinearBase):
|
||||
if is_layer_fp8(layer):
|
||||
if isinstance(layer.quant_method, \
|
||||
CompressedTensorsLinearMethod) and \
|
||||
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8`
|
||||
# seems to store weights as (input, output) instead of
|
||||
# (output, input) so we need to transpose
|
||||
weight = layer.weight.T # standardize to (output, input)
|
||||
else:
|
||||
weight = layer.weight
|
||||
_, weight_scale_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(layer)
|
||||
scales = get_scales(layer) # 已经expand过了
|
||||
weight_scale_group_shape=weight_scale_group_shape.copy() #config中读出来的[128,128], 需要 .copy(), 否则会把config改掉
|
||||
|
||||
# 重新校准一下 weight_scale_group_shape
|
||||
if weight.shape[0] // scales.shape[0] != weight_scale_group_shape[0]:
|
||||
weight_scale_group_shape[0] = weight.shape[0] // scales.shape[0]
|
||||
|
||||
if weight.shape[1] // scales.shape[1] != weight_scale_group_shape[1]:
|
||||
weight_scale_group_shape[1] = weight.shape[1] // scales.shape[1]
|
||||
|
||||
return weight, scales
|
||||
else:
|
||||
return layer.weight, None
|
||||
|
||||
def get_fp8_layer_weight_test(layer: LinearBase):
|
||||
if is_layer_fp8(layer):
|
||||
if isinstance(layer.quant_method, \
|
||||
CompressedTensorsLinearMethod) and \
|
||||
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8`
|
||||
# seems to store weights as (input, output) instead of
|
||||
# (output, input) so we need to transpose
|
||||
weight = layer.weight.T # standardize to (output, input)
|
||||
else:
|
||||
weight = layer.weight
|
||||
_, weight_scale_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(layer)
|
||||
scales = get_scales(layer) # 已经expand过了
|
||||
weight_scale_group_shape=weight_scale_group_shape.copy() #config中读出来的[128,128], 需要 .copy(), 否则会把config改掉
|
||||
|
||||
# 重新校准一下 weight_scale_group_shape
|
||||
if weight.shape[0] // scales.shape[0] != weight_scale_group_shape[0]:
|
||||
weight_scale_group_shape[0] = weight.shape[0] // scales.shape[0]
|
||||
|
||||
if weight.shape[1] // scales.shape[1] != weight_scale_group_shape[1]:
|
||||
weight_scale_group_shape[1] = weight.shape[1] // scales.shape[1]
|
||||
|
||||
# for test
|
||||
weight = scaled_dequantize(weight, scales, weight_scale_group_shape)
|
||||
# print(f'{weight.shape}, {scales.shape}, {weight_scale_group_shape}')
|
||||
return weight, scales
|
||||
else:
|
||||
return layer.weight, None
|
||||
|
||||
def check_eq(name, tensor0, tensor1):
|
||||
assert tensor0.shape == tensor1.shape
|
||||
isEqual = torch.equal(tensor0.reshape([-1]).float(), tensor1.reshape([-1]).float())
|
||||
print(f"{os.getpid()} check {name} {tensor0.shape} equal: {isEqual}")
|
||||
return isEqual
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: LinearBase):
|
||||
if is_layer_fp8(layer):
|
||||
if isinstance(layer.quant_method, \
|
||||
CompressedTensorsLinearMethod) and \
|
||||
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8`
|
||||
# seems to store weights as (input, output) instead of
|
||||
# (output, input) so we need to transpose
|
||||
weight = layer.weight.T # standardize to (output, input)
|
||||
else:
|
||||
weight = layer.weight
|
||||
_, weight_scale_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(layer)
|
||||
scales = get_scales(layer) # 已经expand过了
|
||||
weight_scale_group_shape=weight_scale_group_shape.copy() #config中读出来的[128,128], 需要 .copy(), 否则会把config改掉
|
||||
|
||||
# 重新校准一下 weight_scale_group_shape
|
||||
if weight.shape[0] // scales.shape[0] != weight_scale_group_shape[0]:
|
||||
weight_scale_group_shape[0] = weight.shape[0] // scales.shape[0]
|
||||
|
||||
if weight.shape[1] // scales.shape[1] != weight_scale_group_shape[1]:
|
||||
weight_scale_group_shape[1] = weight.shape[1] // scales.shape[1]
|
||||
|
||||
return scaled_dequantize(weight, scales,
|
||||
weight_scale_group_shape)
|
||||
else:
|
||||
return layer.weight
|
||||
|
||||
if not (quantization_scheme_supported(self.kv_b_proj) and\
|
||||
quantization_scheme_supported(self.q_proj) and\
|
||||
quantization_scheme_supported(self.o_proj)):
|
||||
raise NotImplementedError(
|
||||
"Only FP8 and UnquantizedLinearMethod are supported for MLA"
|
||||
", please run with VLLM_MLA_DISABLE=1")
|
||||
|
||||
weight_dtype = self.kv_b_proj.weight.dtype
|
||||
assert self.o_proj.weight.dtype == weight_dtype
|
||||
assert self.q_proj.weight.dtype == weight_dtype
|
||||
|
||||
if W_Q_W_QR_WUV_WUK_USE_FP8: #and not envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
# 512,1024(=4x256)
|
||||
kv_b_proj_weight, kv_b_proj_scale = \
|
||||
[t.T for t in get_fp8_layer_weight(self.kv_b_proj)]
|
||||
|
||||
# kv_b_proj_weight = kv_b_proj_weight.transpose(-1,-2).contiguous().transpose(-1,-2)
|
||||
N, K = kv_b_proj_weight.shape[0], kv_b_proj_weight.shape[1]
|
||||
|
||||
# 512,1024 => 512,4,256
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
kv_b_proj_scale = kv_b_proj_scale.view(
|
||||
kv_b_proj_scale.shape[0] * self.kv_lora_rank // N,
|
||||
self.num_heads,
|
||||
kv_b_proj_scale.shape[1] * N // (self.kv_lora_rank * self.num_heads),
|
||||
)
|
||||
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
W_UK = W_UK.contiguous()
|
||||
|
||||
scale_0 = kv_b_proj_scale.shape[-1] * self.qk_nope_head_dim // (self.qk_nope_head_dim + self.v_head_dim)
|
||||
scale_1 = kv_b_proj_scale.shape[-1] - scale_0
|
||||
|
||||
W_UK_scale, W_UV_scale = kv_b_proj_scale.split(
|
||||
[scale_0, scale_1], dim=-1)
|
||||
W_UK_scale = W_UK_scale.view(W_UK_scale.shape[0], -1).unsqueeze(-1).contiguous()
|
||||
W_UV_scale = W_UV_scale.view(W_UV_scale.shape[0], -1).unsqueeze(-1)
|
||||
|
||||
# weight: [1536, 768] scale: 12,6
|
||||
q_proj_weight, q_proj_scale = \
|
||||
[t.T for t in get_fp8_layer_weight(self.q_proj)]
|
||||
|
||||
#self.W_Q_QR = q_proj_weight.contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
#self.W_Q_QR_scales = q_proj_scale.reshape(12, 6, 1).repeat(1, 1, 4).reshape(12, -1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
|
||||
q_proj_weight = q_proj_weight\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
# w_q[1536, 512] + w_qr[1536, 256]
|
||||
W_Q = q_proj_weight[..., :self.qk_nope_head_dim].flatten(start_dim=1)
|
||||
W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
# w_q_scale 12,16 + w_qr_scale 12,8
|
||||
# expand: 12,6(4+2) -> 12,24(16+8)
|
||||
# Q_scale: [s0x4, s1x2, s2x2, s3x4, s4x2, s5x2]
|
||||
repeat_pattern = torch.tensor([4, 2, 2, 4, 2, 2], device=q_proj_scale.device)
|
||||
W_Q_scale = torch.repeat_interleave(q_proj_scale, repeat_pattern, dim=1)
|
||||
# Q_R_scale: [s1x2, s2x2, s4x2, s5x2]
|
||||
selected_indices = [1, 2, 4, 5]
|
||||
repeat_times = 2
|
||||
selected = q_proj_scale[:, selected_indices]
|
||||
W_QR_scale = selected.repeat_interleave(repeat_times, dim=1)
|
||||
|
||||
# temp_WQ_Scale = W_Q_scale.reshape(12, 4, -1).contiguous()
|
||||
# temp_W_QR_scale = W_QR_scale.reshape(12, 4, -1).contiguous()
|
||||
# temp_scale = torch.cat([temp_WQ_Scale, temp_W_QR_scale], dim=2).contiguous().reshape(12, -1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
# self.W_Q_QR_scales = temp_scale
|
||||
# print("W_Q_scale:", W_Q_scale.shape)
|
||||
# print("W_QR_scale:", W_QR_scale.shape)
|
||||
# print("temp_scale:", temp_scale.shape)
|
||||
# exit(0)
|
||||
|
||||
# Note: to be vnnl compatible
|
||||
# 1. expand w_uv scale for core split friendly
|
||||
if W_UV.shape[-1] % 4 == 0:
|
||||
W_UV_scale = W_UV_scale.expand((W_UV_scale.shape[0], W_UV_scale.shape[1], 4))
|
||||
# 2. change w_q, w_qr, w_uv weight&scale to K-contiguous (shape unchanged)
|
||||
W_Q = W_Q.transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
W_Q_scale = W_Q_scale.transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
W_QR = W_QR.transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
W_QR_scale = W_QR_scale.transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
|
||||
W_UV = W_UV.permute(2,1,0).contiguous().permute(2,1,0)
|
||||
W_UV_scale = W_UV_scale.permute(2,1,0).contiguous().permute(2,1,0)
|
||||
|
||||
self.W_Q = W_Q
|
||||
self.W_Q_scales = W_Q_scale
|
||||
|
||||
self.W_QR = W_QR
|
||||
self.W_QR_scales = W_QR_scale
|
||||
|
||||
# temp_Q_scale = self.W_Q_scales.contiguous()
|
||||
# temp_W_QR_scale = self.W_QR_scales.contiguous()
|
||||
# self.W_Q_QR = q_proj_weight.reshape(1536, -1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
# self.W_Q_QR_scales = torch.concat([temp_Q_scale,temp_W_QR_scale],dim=1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
#self.W_Q_QR = torch.concat([self.W_Q.contiguous(),self.W_QR.contiguous()],dim=1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
#self.W_Q_QR_scales = torch.concat([W_Q_scale,W_QR_scale],dim=1).contiguous().transpose(-2,-1).contiguous().transpose(-2,-1)
|
||||
|
||||
self.W_UV = W_UV
|
||||
self.W_UV_scales = W_UV_scale
|
||||
|
||||
self.W_UK = W_UK
|
||||
self.W_UK_scales = W_UK_scale
|
||||
return
|
||||
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
||||
f"{kv_b_proj_weight.shape=}, "
|
||||
f"{self.kv_lora_rank=}, "
|
||||
f"{self.num_heads=}, "
|
||||
f"{self.qk_nope_head_dim=}, "
|
||||
f"{self.v_head_dim=}")
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
|
||||
# can be W_Q or W_UQ depending q_lora_rank, the former if
|
||||
# q_lora_rank is None, the latter otherwise. From the Attention backend
|
||||
# perspective though we call these both W_Q and rely on the layer
|
||||
# to pass in the correct matrix
|
||||
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
|
||||
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
|
||||
# W_QR is small so for simplicity we dont bother requantizing it
|
||||
self.W_QR = self.W_QR.to(act_dtype)
|
||||
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
assert False, "please set VLLM_MLA_PERFORM_MATRIX_ABSORPTION=0"
|
||||
requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
# This assumes it wise to requantize using the same group shapes
|
||||
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
|
||||
# weights were originally quantized
|
||||
requant_input_group_shape, requant_weight_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(self.q_proj)
|
||||
assert (requant_input_group_shape, requant_weight_group_shape)\
|
||||
== get_scale_group_shapes_for_fp8(self.kv_b_proj)
|
||||
assert (requant_input_group_shape, requant_weight_group_shape)\
|
||||
== get_scale_group_shapes_for_fp8(self.o_proj)
|
||||
self.reqaunt_input_group_shape = requant_input_group_shape
|
||||
self.reqaunt_weight_group_shape = requant_weight_group_shape
|
||||
|
||||
#
|
||||
# Perform matrix-absorption following
|
||||
# https://github.com/flashinfer-ai/flashinfer/pull/551
|
||||
# for decode, as a result we end up with absorbed weights for decode
|
||||
# and another copy of raw weights for prefill.
|
||||
#
|
||||
self.W_UK, self.W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
# We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
|
||||
# depending q_lora_rank, the former if q_lora_rank is None, the
|
||||
# latter otherwise
|
||||
# basically if q_lora_rank is none we are absorbing into q_proj
|
||||
# instead of UQ
|
||||
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
W_Q_UK, W_Q_UK_scales = scaled_quantize(
|
||||
W_Q_UK,
|
||||
self.reqaunt_weight_group_shape,
|
||||
quant_dtype=current_platform_fp8_dtype)
|
||||
# For FP8 save the transpose so we can use
|
||||
# `apply_w8a8_block_fp8_linear` directly
|
||||
self.W_Q_UK = W_Q_UK.T.contiguous()
|
||||
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
|
||||
else:
|
||||
self.W_Q_UK = W_Q_UK.to(act_dtype)
|
||||
|
||||
W_O = get_and_maybe_dequant_weights(self.o_proj)\
|
||||
.view(-1, self.num_heads, self.v_head_dim)
|
||||
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
|
||||
.flatten(start_dim=0, end_dim=1).contiguous()
|
||||
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
W_UV_O, W_UV_O_scales = scaled_quantize(
|
||||
W_UV_O,
|
||||
self.reqaunt_weight_group_shape,
|
||||
quant_dtype=current_platform_fp8_dtype)
|
||||
# For FP8 save the transpose so we can use
|
||||
# `apply_w8a8_block_fp8_linear` directly
|
||||
self.W_UV_O = W_UV_O.T.contiguous()
|
||||
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
|
||||
else:
|
||||
self.W_UV_O = W_UV_O.to(act_dtype)
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
else:
|
||||
# print('W_UV', W_UV.dtype) #float32
|
||||
#if is_fp8(weight_dtype):
|
||||
# raise NotImplementedError(
|
||||
# "Currently fp8 requires matrix absorption")
|
||||
# self.W_UV = W_UV
|
||||
# self.W_UK = W_UK
|
||||
self.W_UV = W_UV.to(act_dtype) # fp32 to bfp16
|
||||
self.W_UK = W_UK.to(act_dtype)
|
||||
W_Q = W_Q.to(act_dtype)
|
||||
self.W_Q = W_Q.flatten(start_dim=1)
|
||||
726
vllm_vacc/vllm/attention/backends/vacc_attn.py
Normal file
726
vllm_vacc/vllm/attention/backends/vacc_attn.py
Normal file
@@ -0,0 +1,726 @@
|
||||
""" Attention layer with torch scaled_dot_product_attention
|
||||
and PagedAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
# from vllm.attention.backends.utils import CommonAttentionState
|
||||
# from vllm.attention.ops.paged_attn import PagedAttentionMetadata
|
||||
|
||||
from vllm_vacc.vllm.attention.ops.vacc_paged_attn import VaccPagedAttention as PagedAttention
|
||||
|
||||
# from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
# AttentionLayer,
|
||||
# AttentionMetadata,
|
||||
# AttentionMetadataBuilder,
|
||||
# AttentionType)
|
||||
# from vllm.attention.backends.utils import CommonAttentionState
|
||||
# from vllm.attention.ops.ipex_attn import PagedAttention
|
||||
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import make_tensor_with_pad
|
||||
from vllm_vacc.vllm.v1.worker.vacc_model_runner import ModelInputForVACCBuilder
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
import os
|
||||
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
||||
|
||||
class VACCAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TORCH_VACC"
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["VACCAttentionBackendImpl"]:
|
||||
return VACCAttentionBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||
return VACCAttentionMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["VACCMetadataBuilder"]:
|
||||
return VACCMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
|
||||
num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VACCAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
"""Metadata for VACCAttentionMetadata.
|
||||
"""
|
||||
# Currently, input sequences can only contain all prompts
|
||||
# or all decoding. True if all sequences are prompts.
|
||||
chunked_prefill: bool
|
||||
seq_lens: Optional[List[int]] = None # For non-chunked prefill
|
||||
# For chunked prefill only
|
||||
max_query_len: Optional[int] = None
|
||||
max_kv_len: Optional[int] = None
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
kv_start_loc: Optional[torch.Tensor] = None
|
||||
prefill_block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
# Begin encoder attn & enc/dec cross-attn fields...
|
||||
# Encoder sequence lengths representation
|
||||
encoder_seq_lens: Optional[List[int]] = None
|
||||
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# Maximum sequence length among encoder sequences
|
||||
max_encoder_seq_len: Optional[int] = None
|
||||
|
||||
# Number of tokens input to encoder
|
||||
num_encoder_tokens: Optional[int] = None
|
||||
|
||||
# Cross-attention memory-mapping data structures: slot mapping
|
||||
# and block tables
|
||||
cross_slot_mapping: Optional[torch.Tensor] = None
|
||||
cross_block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Set during the execution of the first attention op.
|
||||
# It is a list because it is needed to set per prompt
|
||||
# when alibi slopes is used. It is because of the limitation
|
||||
# from xformer API.
|
||||
# will not appear in the __repr__ and __init__
|
||||
self.attn_bias: Optional[List[torch.Tensor]] = None
|
||||
self.encoder_attn_bias: Optional[List[torch.Tensor]] = None
|
||||
self.cross_attn_bias: Optional[List[torch.Tensor]] = None
|
||||
|
||||
@property
|
||||
def is_all_encoder_attn_metadata_set(self):
|
||||
'''
|
||||
All attention metadata required for encoder attention is set.
|
||||
'''
|
||||
return ((self.encoder_seq_lens is not None)
|
||||
and (self.encoder_seq_lens_tensor is not None)
|
||||
and (self.max_encoder_seq_len is not None))
|
||||
|
||||
@property
|
||||
def is_all_cross_attn_metadata_set(self):
|
||||
'''
|
||||
All attention metadata required for enc/dec cross-attention is set.
|
||||
|
||||
Superset of encoder attention required metadata.
|
||||
'''
|
||||
return (self.is_all_encoder_attn_metadata_set
|
||||
and (self.cross_slot_mapping is not None)
|
||||
and (self.cross_block_tables is not None))
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["VACCAttentionMetadata"]:
|
||||
# Currently chunked prefill is not supported
|
||||
if self.num_prefill_tokens == 0:
|
||||
return None
|
||||
return self
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["VACCAttentionMetadata"]:
|
||||
# Currently chunked prefill is not supported
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
return self
|
||||
|
||||
def get_seq_lens(
|
||||
self,
|
||||
attn_type: AttentionType,
|
||||
):
|
||||
'''
|
||||
Extract appropriate sequence lengths from attention metadata
|
||||
according to attention type.
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
|
||||
Returns:
|
||||
* Appropriate sequence lengths tensor for query
|
||||
* Appropriate sequence lengths tensor for key & value
|
||||
'''
|
||||
|
||||
if (attn_type == AttentionType.DECODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY):
|
||||
seq_lens_q = self.seq_lens
|
||||
seq_lens_kv = self.seq_lens
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
seq_lens_q = self.encoder_seq_lens
|
||||
seq_lens_kv = self.encoder_seq_lens
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
seq_lens_q = self.seq_lens
|
||||
seq_lens_kv = self.encoder_seq_lens
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
return seq_lens_q, seq_lens_kv
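
    # Illustrative example (not from the original source): for decoder
    # self-attention with seq_lens = [5, 7], get_seq_lens returns
    # ([5, 7], [5, 7]); for encoder/decoder cross-attention the query
    # lengths come from seq_lens while the key/value lengths come from
    # encoder_seq_lens.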
|
||||
|
||||
def get_attn_bias(
|
||||
self,
|
||||
attn_type: AttentionType,
|
||||
) -> Optional[List[torch.Tensor]]:
|
||||
'''
|
||||
Extract appropriate attention bias from attention metadata
|
||||
according to attention type.
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
|
||||
Returns:
|
||||
* Appropriate attention bias value given the attention type
|
||||
'''
|
||||
|
||||
if (attn_type == AttentionType.DECODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY):
|
||||
return self.attn_bias
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
return self.encoder_attn_bias
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
return self.cross_attn_bias
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
def set_attn_bias(
|
||||
self,
|
||||
attn_bias: List[torch.Tensor],
|
||||
attn_type: AttentionType,
|
||||
) -> None:
|
||||
'''
|
||||
Update appropriate attention bias field of attention metadata,
|
||||
according to attention type.
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention
|
||||
* attn_bias: The desired attention bias value
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
'''
|
||||
|
||||
if (attn_type == AttentionType.DECODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY):
|
||||
self.attn_bias = attn_bias
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
self.encoder_attn_bias = attn_bias
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
self.cross_attn_bias = attn_bias
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
def get_seq_len_block_table_args(
|
||||
self,
|
||||
attn_type: str,
|
||||
) -> tuple:
|
||||
'''
|
||||
The particular choice of sequence-length- and block-table-related
|
||||
attributes which should be extracted from attn_metadata is dependent
|
||||
on the type of attention operation.
|
||||
|
||||
Decoder attn -> select entirely decoder self-attention-related fields
|
||||
Encoder/decoder cross-attn -> select encoder sequence lengths &
|
||||
cross-attn block-tables fields
|
||||
Encoder attn -> select encoder sequence lengths fields & no block tables
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention
|
||||
* is_prompt: True if prefill, False otherwise
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
|
||||
Returns:
|
||||
|
||||
* Appropriate sequence-lengths tensor
|
||||
* Appropriate max sequence-length scalar
|
||||
* Appropriate block tables (or None)
|
||||
'''
|
||||
|
||||
if (attn_type == AttentionType.DECODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY):
|
||||
# Decoder self-attention
|
||||
# Choose max_seq_len based on whether we are in prompt_run
|
||||
return (self.seq_lens_tensor, self.max_decode_seq_len,
|
||||
self.block_tables)
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Enc/dec cross-attention KVs match encoder sequence length;
|
||||
# cross-attention utilizes special "cross" block tables
|
||||
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
|
||||
self.cross_block_tables)
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
# No block tables associated with encoder attention
|
||||
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
|
||||
None)
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
class VACCMetadataBuilder(AttentionMetadataBuilder[VACCAttentionMetadata]):
|
||||
|
||||
def __init__(self, input_builder: ModelInputForVACCBuilder) -> None:
|
||||
self.chunked_prefill = input_builder.chunked_prefill
|
||||
self.input_builder = input_builder
|
||||
|
||||
def prepare(self):
|
||||
self.input_data = self.input_builder.input_data
|
||||
|
||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||
cuda_graph_pad_size: int, batch_size: int) -> VACCAttentionMetadata:
|
||||
input_data = self.input_data
|
||||
prefill_seq_lens = seq_lens[0:input_data.num_prefills]
|
||||
prefill_query_lens = query_lens[0:input_data.num_prefills]
|
||||
slot_mapping = torch.tensor(input_data.slot_mapping,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device)
|
||||
|
||||
# For chunked-prefill
|
||||
if self.chunked_prefill and input_data.num_prefill_tokens != 0:
|
||||
prefill_block_tables = make_tensor_with_pad(
|
||||
self.input_data.prefill_block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device,
|
||||
)
|
||||
query_lens_tensor = torch.tensor(prefill_query_lens,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device)
|
||||
kv_lens_tensor = torch.tensor(prefill_seq_lens,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device)
|
||||
query_start_loc = torch.zeros(input_data.num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device)
|
||||
kv_start_loc = torch.zeros(input_data.num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device)
|
||||
torch.cumsum(query_lens_tensor,
|
||||
dim=0,
|
||||
dtype=torch.int32,
|
||||
out=query_start_loc[1:])
|
||||
torch.cumsum(kv_lens_tensor,
|
||||
dim=0,
|
||||
dtype=torch.int32,
|
||||
out=kv_start_loc[1:])
|
||||
max_query_len = max(prefill_query_lens)
|
||||
max_kv_len = max(prefill_seq_lens)
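            # Illustrative example (not from the original source): for
            # prefill_query_lens = [4, 6], the cumsum above yields
            # query_start_loc = [0, 4, 10], so request i occupies
            # query[query_start_loc[i]:query_start_loc[i + 1]].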
|
||||
else:
|
||||
prefill_block_tables = None
|
||||
query_start_loc = None
|
||||
kv_start_loc = None
|
||||
max_query_len = None
|
||||
max_kv_len = None
|
||||
|
||||
# For paged attention
|
||||
if input_data.num_decode_tokens != 0:
|
||||
seq_lens_tensor = torch.tensor(
|
||||
input_data.seq_lens[input_data.num_prefills:],
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device,
|
||||
)
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.input_data.decode_block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device,
|
||||
)
|
||||
# lowest_dim_size = block_tables.size(-1)
|
||||
# if lowest_dim_size < 1024:
|
||||
# padding_amount = 1024 - lowest_dim_size
|
||||
# padding = torch.zeros(*block_tables.size()[:-1], padding_amount, dtype=block_tables.dtype, device=block_tables.device)
|
||||
# block_tables = torch.cat((block_tables, padding), dim=-1)
|
||||
else:
|
||||
block_tables = torch.tensor([])
|
||||
seq_lens_tensor = torch.tensor(
|
||||
input_data.seq_lens[:input_data.num_prefills],
|
||||
dtype=torch.int32,
|
||||
device=self.input_builder.device,
|
||||
)
|
||||
|
||||
# For multi-modal models
|
||||
placeholder_index_maps = None
|
||||
if len(input_data.multi_modal_inputs_list) != 0:
|
||||
placeholder_index_maps = {
|
||||
modality: placeholder_map.index_map()
|
||||
for modality, placeholder_map in
|
||||
input_data.multi_modal_placeholder_maps.items()
|
||||
}
|
||||
|
||||
attn_metadata = VACCAttentionMetadata(
|
||||
chunked_prefill=self.chunked_prefill,
|
||||
seq_lens=seq_lens, #prefill_seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=max_query_len,
|
||||
max_kv_len=max_kv_len,
|
||||
query_start_loc=query_start_loc,
|
||||
kv_start_loc=kv_start_loc,
|
||||
max_decode_seq_len=None,
|
||||
num_prefills=input_data.num_prefills,
|
||||
num_prefill_tokens=input_data.num_prefill_tokens,
|
||||
num_decode_tokens=input_data.num_decode_tokens,
|
||||
block_tables=block_tables,
|
||||
prefill_block_tables=prefill_block_tables,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=False,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
class VACCAttentionBackendImpl(AttentionImpl[VACCAttentionMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
) -> None:
|
||||
        if blocksparse_params is not None:
            raise ValueError(
                "Torch SDPA does not support block-sparse attention.")
        if logits_soft_cap is not None:
            logger.warning_once("Torch SDPA does not support logits soft cap. "
                                "Outputs may be slightly off.")
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
self.sliding_window = sliding_window
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.need_mask = (self.alibi_slopes is not None
|
||||
or self.sliding_window is not None)
|
||||
|
||||
supported_head_sizes = PagedAttention.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by PagedAttention. "
|
||||
f"Supported head sizes are: {supported_head_sizes}.")
|
||||
if kv_cache_dtype != "auto":
|
||||
raise NotImplementedError(
|
||||
"Torch SDPA backend does not support FP8 KV cache. "
|
||||
"Please use xFormers backend instead.")
|
||||
self.attn_type = attn_type
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: VACCAttentionMetadata, # type: ignore
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with torch SDPA and PagedAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
|
||||
NOTE: kv_cache will be an empty tensor with shape [0]
|
||||
for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
attn_type = self.attn_type
|
||||
if (attn_type == AttentionType.ENCODER
|
||||
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
||||
raise AttributeError("Encoder attention requires setting "
|
||||
"encoder metadata attributes.")
|
||||
elif (attn_type == AttentionType.ENCODER_DECODER
|
||||
and (not attn_metadata.is_all_cross_attn_metadata_set)):
|
||||
raise AttributeError("Encoder/decoder cross-attention "
|
||||
"requires setting cross-attention "
|
||||
"metadata attributes.")
|
||||
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
if key is not None:
|
||||
assert value is not None
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
else:
|
||||
assert value is None
|
||||
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
|
||||
# KV-cache during decoder-self- or
|
||||
# encoder-decoder-cross-attention, but not
|
||||
# during encoder attention.
|
||||
#
|
||||
# Even if there are no new key/value pairs to cache,
|
||||
# we still need to break out key_cache and value_cache
|
||||
# i.e. for later use by paged attention
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
|
||||
|
||||
if (key is not None) and (value is not None):
|
||||
if attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Update cross-attention KV cache (prefill-only)
|
||||
# During cross-attention decode, key & value will be None,
|
||||
# preventing this IF-statement branch from running
|
||||
updated_slot_mapping = attn_metadata.cross_slot_mapping
|
||||
else:
|
||||
# Update self-attention KV cache (prefill/decode)
|
||||
updated_slot_mapping = attn_metadata.slot_mapping
|
||||
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale, layer._v_scale)
|
||||
|
||||
if attn_type != AttentionType.ENCODER:
|
||||
# Decoder self-attention supports chunked prefill.
|
||||
# Encoder/decoder cross-attention requires no chunked
|
||||
# prefill (100% prefill or 100% decode tokens, no mix)
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
else:
|
||||
# Encoder attention - chunked prefill is not applicable;
|
||||
# derive token-count from query shape & and treat them
|
||||
# as 100% prefill tokens
|
||||
assert attn_metadata.num_encoder_tokens is not None
|
||||
num_prefill_tokens = attn_metadata.num_encoder_tokens
|
||||
num_decode_tokens = 0
|
||||
|
||||
if attn_type == AttentionType.DECODER:
|
||||
# Only enforce this shape-constraint for decoder
|
||||
# self-attention
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
output = torch.empty_like(query)
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
assert attn_metadata.seq_lens is not None
|
||||
if (kv_cache.numel() == 0
|
||||
or prefill_meta.block_tables.numel() == 0):
|
||||
self._run_vacc_forward(
|
||||
output,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
prefill_meta,
|
||||
attn_type=attn_type)
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
assert not self.need_mask
|
||||
import intel_extension_for_pytorch.llm.modules as ipex_modules
|
||||
output = torch.empty_like(query)
|
||||
ipex_modules.PagedAttention.flash_attn_varlen_func(
|
||||
output[:prefill_meta.num_prefill_tokens, :, :],
|
||||
query[:prefill_meta.num_prefill_tokens, :, :],
|
||||
key_cache,
|
||||
value_cache,
|
||||
prefill_meta.query_start_loc,
|
||||
prefill_meta.kv_start_loc,
|
||||
prefill_meta.max_query_len,
|
||||
prefill_meta.max_kv_len,
|
||||
self.scale,
|
||||
True,
|
||||
prefill_meta.prefill_block_tables,
|
||||
self.alibi_slopes,
|
||||
)
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
assert attn_type != AttentionType.ENCODER_ONLY, (
|
||||
"Encoder-only models should not have decode metadata.")
|
||||
# Decoding run.
|
||||
# (
|
||||
# seq_lens_arg,
|
||||
# max_seq_len_arg,
|
||||
# block_tables_arg,
|
||||
# ) = decode_meta.get_seq_len_block_table_args(attn_type)
|
||||
|
||||
            # Note: decode attention still uses the SDPA method.
            # Reshape k/v_cache to (num_block_grp, block_grp_size, head, hidden_size).
            k_cache = key_cache.view(-1, env_blk_grp_size, key_cache.shape[2], key_cache.shape[3])
            v_cache = value_cache.view(-1, env_blk_grp_size, value_cache.shape[2], value_cache.shape[3])
            block_per_group = env_blk_grp_size // 16
            # Convert block_tables to 8K-group indices.
            block_tables = (decode_meta.block_tables // block_per_group).to(torch.int32)
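            # Illustrative example (assumption, not from the original source):
            # with env_blk_grp_size = 8192 and 16-token physical blocks,
            # block_per_group = 512, so physical block id 1030 maps to group
            # index 1030 // 512 = 2 in the regrouped cache view above.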
|
||||
attn_outs = []
|
||||
for i in range(decode_meta.seq_lens_tensor.shape[0]):
|
||||
seq_len = decode_meta.seq_lens_tensor[i]
|
||||
k_slices = k_cache[block_tables[i], ...]
|
||||
k = \
|
||||
torch.cat([k_slices[i, ...] for i in range(len(block_tables[i]))], dim=0)[:seq_len]
|
||||
v_slices = v_cache[block_tables[i], ...]
|
||||
v = \
|
||||
torch.cat([v_slices[i, ...] for i in range(len(block_tables[i]))], dim=0)[:seq_len]
|
||||
q = query[i : i + 1, ...]
|
||||
attn_out = torch.vacc.scaled_dot_product_attention(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=False,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=self.scale,
|
||||
)
|
||||
attn_outs.append(attn_out)
|
||||
output = torch.cat(attn_outs, dim=0)
|
||||
# '''
|
||||
|
||||
# PagedAttention.forward_decode(
|
||||
# output[attn_metadata.num_prefill_tokens:, :, :],
|
||||
# query[attn_metadata.num_prefill_tokens:, :, :],
|
||||
# key_cache,
|
||||
# value_cache,
|
||||
# block_tables_arg,
|
||||
# seq_lens_arg,
|
||||
# max_seq_len_arg,
|
||||
# self.kv_cache_dtype,
|
||||
# self.num_kv_heads,
|
||||
# self.scale,
|
||||
# self.alibi_slopes,
|
||||
# layer._k_scale,
|
||||
# layer._v_scale,
|
||||
# )
|
||||
# Reshape the output tensor.
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
|
||||
def _run_vacc_forward(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_metadata: VACCAttentionMetadata,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
):
|
||||
# if self.num_kv_heads != self.num_heads:
|
||||
# key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
|
||||
# value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
|
||||
attn_masks = attn_metadata.get_attn_bias(attn_type)
|
||||
if attn_masks is None:
|
||||
if self.alibi_slopes is not None:
|
||||
attn_masks = _make_alibi_bias(
|
||||
self.alibi_slopes, query.dtype,
|
||||
attn_metadata.seq_lens) # type: ignore
|
||||
elif self.sliding_window is not None:
|
||||
assert attn_metadata.seq_lens is not None
|
||||
attn_masks = _make_sliding_window_bias(
|
||||
attn_metadata.seq_lens, self.sliding_window,
|
||||
query.dtype) # type: ignore
|
||||
else:
|
||||
seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
|
||||
attn_masks = [None] * len(seq_lens)
|
||||
attn_metadata.set_attn_bias(attn_masks, attn_type)
|
||||
|
||||
causal_attn = (attn_type == AttentionType.DECODER)
|
||||
|
||||
seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
|
||||
start_q, start_kv = 0, 0
|
||||
for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
|
||||
attn_masks):
|
||||
end_q = start_q + seq_len_q
|
||||
end_kv = start_kv + seq_len_kv
|
||||
sub_out=torch.vacc.scaled_dot_product_attention(
|
||||
query[start_q:end_q,:, :],
|
||||
key[start_kv:end_kv,:, :],
|
||||
value[start_kv:end_kv,:, :].contiguous(),
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=True, #causal_attn and not self.need_mask,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=self.scale)
|
||||
output[ start_q:end_q,:, :] = sub_out
|
||||
start_q, start_kv = end_q, end_kv
|
||||
return output
|
||||
|
||||
|
||||
def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[torch.Tensor]:
    attn_biases: List[torch.Tensor] = []
    for seq_len in seq_lens:
        bias = torch.arange(seq_len, dtype=dtype)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        bias = bias[None, :] - bias[:, None]

        num_heads = alibi_slopes.shape[0]
        bias = bias[None, :].repeat((num_heads, 1, 1))
        bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
        inf_mask = torch.empty(
            (1, seq_len, seq_len),
            dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
        attn_biases.append((bias + inf_mask).to(dtype))

    return attn_biases
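
# Sketch (not part of the original file): for alibi_slopes with 2 heads and
# seq_lens = [3], _make_alibi_bias returns a single tensor of shape
# (1, 2, 3, 3) whose upper triangle (future positions) is -inf, e.g.
#   slopes = torch.tensor([0.5, 0.25])
#   bias = _make_alibi_bias(slopes, torch.float32, [3])[0]
#   assert bias.shape == (1, 2, 3, 3)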


def _make_sliding_window_bias(
    seq_lens: List[int],
    window_size: Optional[int],
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    attn_biases: List[torch.Tensor] = []
    for seq_len in seq_lens:
        tensor = torch.full(
            (1, seq_len, seq_len),
            dtype=dtype,
            fill_value=1,
        )
        shift = 0
        mask = torch.tril(tensor, diagonal=shift).to(dtype)  # type: ignore
        if window_size is not None:
            mask = torch.triu(mask, diagonal=shift - window_size + 1)
        mask = torch.log(mask)
        attn_biases.append(mask.to(dtype))

    return attn_biases
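
# Sketch (not part of the original file): torch.log turns the 0/1 mask into
# an additive bias, since log(1) = 0 keeps a position and log(0) = -inf
# masks it out. For seq_len = 4 and window_size = 2 only the main diagonal
# and the first sub-diagonal survive.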
847
vllm_vacc/vllm/attention/backends/vacc_mla.py
Normal file
847
vllm_vacc/vllm/attention/backends/vacc_mla.py
Normal file
@@ -0,0 +1,847 @@
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
|
||||
try:
|
||||
from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
||||
except ImportError:
|
||||
BatchDecodeMlaWithPagedKVCacheWrapper = None
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionState, AttentionType)
|
||||
from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonMetadata
|
||||
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
|
||||
#from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm_vacc.vllm.attention.ops.vacc_paged_attn import VaccPagedAttention as PagedAttention
|
||||
# from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
# import time, os
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm_vacc.vllm.worker.vacc_model_runner import (ModelInputForVACCBuilder,
|
||||
ModelInputForVACCWithSamplingMetadata)
|
||||
|
||||
|
||||
class VACCMLABackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TORCH_VACC"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["VACCMLAImpl"]:
|
||||
return VACCMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||
return VACCMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["VACCMLAMetadataBuilder"]:
|
||||
return VACCMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["VACCMLAState"]:
|
||||
return VACCMLAState
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [576]
|
||||
|
||||
|
||||
class VACCMLAState(AttentionState):
|
||||
|
||||
def __init__(self, runner):
|
||||
self.runner = runner
|
||||
self._is_graph_capturing = False
|
||||
|
||||
@contextmanager
|
||||
def graph_capture(self, max_batch_size: int):
|
||||
self._is_graph_capturing = True
|
||||
|
||||
self._graph_slot_mapping = torch.full((max_batch_size, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
self._graph_seq_lens = torch.ones(max_batch_size,
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
self._graph_block_tables = torch.from_numpy(
|
||||
self.runner.graph_block_tables).to(device=self.runner.device)
|
||||
|
||||
self._positions = torch.zeros((max_batch_size, ),
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
|
||||
yield
|
||||
|
||||
self._is_graph_capturing = False
|
||||
del self._graph_slot_mapping
|
||||
del self._graph_seq_lens
|
||||
del self._graph_block_tables
|
||||
del self._positions
|
||||
|
||||
def graph_clone(self, batch_size: int):
|
||||
assert self._is_graph_capturing
|
||||
return self.__class__(self.runner)
|
||||
|
||||
def graph_capture_get_metadata_for_batch(
|
||||
self, batch_size: int, is_encoder_decoder_model: bool = False):
|
||||
assert self._is_graph_capturing
|
||||
|
||||
attn_metadata = self.runner.attn_backend.make_metadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=batch_size,
|
||||
slot_mapping=self._graph_slot_mapping[:batch_size],
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=self._graph_seq_lens[:batch_size],
|
||||
# max_query_len=1,
|
||||
# max_decode_query_len=1,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.runner.max_seq_len_to_capture,
|
||||
query_start_loc=None,
|
||||
seq_start_loc=None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=self._graph_block_tables[:batch_size],
|
||||
use_cuda_graph=True,
|
||||
input_positions=self._positions[:batch_size],
|
||||
head_dim=self.runner.model_config.get_head_size())
|
||||
|
||||
if is_encoder_decoder_model:
|
||||
raise NotImplementedError(
|
||||
"VACCMLAState does not support encoder/decoder yet")
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def get_graph_input_buffers(self,
|
||||
attn_metadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
input_buffers = {
|
||||
"slot_mapping": attn_metadata.slot_mapping,
|
||||
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
|
||||
"block_tables": attn_metadata.decode_metadata.block_tables,
|
||||
"input_positions": attn_metadata.decode_metadata.input_positions,
|
||||
}
|
||||
if is_encoder_decoder_model:
|
||||
raise NotImplementedError(
|
||||
"VACCMLAState does not support encoder/decoder yet")
|
||||
|
||||
return input_buffers
|
||||
|
||||
def prepare_graph_input_buffers(self,
|
||||
input_buffers,
|
||||
attn_metadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
input_positions = attn_metadata.input_positions
|
||||
num_positions = input_positions.shape[0]
|
||||
input_buffers["seq_lens_tensor"].copy_(
|
||||
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
|
||||
input_buffers["block_tables"].copy_(
|
||||
attn_metadata.decode_metadata.block_tables, non_blocking=True)
|
||||
# CUDA graph buffer is padded so only perform a partial copy based on
|
||||
# num_positions
|
||||
input_buffers["input_positions"][:num_positions].copy_(
|
||||
input_positions, non_blocking=True)
|
||||
if is_encoder_decoder_model:
|
||||
raise NotImplementedError(
|
||||
"VACCMLAState does not support encoder/decoder yet")
|
||||
|
||||
def begin_forward(self, model_input):
|
||||
return
|
||||
|
||||
|
||||
@dataclass
|
||||
class VACCMLAMetadata(MLACommonMetadata):
|
||||
"""Metadata for VACCMLAMetadata.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
cuda-graph replayed. If you have values that need to be changed
|
||||
dynamically, it should be stored in tensor. The tensor has to be
|
||||
updated from `CUDAGraphRunner.forward` API.
|
||||
"""
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]]
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
||||
# in the kv cache. Each block can contain up to block_size tokens.
|
||||
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
||||
# captured.
|
||||
block_tables: Optional[torch.Tensor]
|
||||
|
||||
    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool
|
||||
|
||||
# Maximum query length in the batch.
|
||||
max_query_len: Optional[int] = None
|
||||
input_positions: Optional[torch.Tensor] = None
|
||||
# Max number of query tokens among request in the batch.
|
||||
max_decode_query_len: Optional[int] = None
|
||||
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
_cached_prefill_metadata: Optional["VACCMLAMetadata"] = None
|
||||
_cached_decode_metadata: Optional["VACCMLAMetadata"] = None
|
||||
|
||||
num_prefill_tokens: int
|
||||
|
||||
num_kv_splits: int = 4 # TODO(lucas) add heuristic
|
||||
attn_logits: Optional[torch.Tensor] = None
|
||||
req_idx: Optional[torch.Tensor] = None
|
||||
|
||||
# The dimension of the attention heads
|
||||
head_dim: Optional[int] = None
|
||||
|
||||
def __post_init__(self):
|
||||
supported_head_sizes = VACCMLABackend.get_supported_head_sizes()
|
||||
if self.head_dim is not None and self.head_dim \
|
||||
not in supported_head_sizes:
|
||||
raise ValueError(
|
||||
f"Only {supported_head_sizes} are supported for head_dim,",
|
||||
f"received {self.head_dim}.")
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["VACCMLAMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
query_start_loc = (None if self.query_start_loc is None else
|
||||
self.query_start_loc[:self.num_prefills + 1])
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[:self.num_prefill_tokens])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[:self.num_prefills])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[:self.num_prefills])
|
||||
seq_start_loc = (None if self.seq_start_loc is None else
|
||||
self.seq_start_loc[:self.num_prefills + 1])
|
||||
context_lens_tensor = (None if self.context_lens_tensor is None else
|
||||
self.context_lens_tensor[:self.num_prefills])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
input_positions = (None if self.input_positions is None else
|
||||
self.input_positions[:self.num_prefill_tokens])
|
||||
|
||||
self._cached_prefill_metadata = VACCMLAMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
|
||||
input_positions=input_positions,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_prefill_seq_len=None,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_start_loc=seq_start_loc,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=False,
|
||||
head_dim=self.head_dim)
|
||||
return self._cached_prefill_metadata
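
    # Illustrative note (not from the original source): with num_prefills = 2
    # and num_prefill_tokens = 9, the cached prefill view above slices the
    # first 2 rows of the per-sequence tensors and the first 9 rows of the
    # per-token tensors (slot_mapping, input_positions), while
    # decode_metadata below takes the complementary tails.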
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["VACCMLAMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
return self._cached_decode_metadata
|
||||
assert self.seq_lens_tensor is not None
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[self.num_prefill_tokens:])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
input_positions = (None if self.input_positions is None else
|
||||
self.input_positions[self.num_prefill_tokens:])
|
||||
|
||||
self._cached_decode_metadata = VACCMLAMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens=self.seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_decode_query_len=self.max_decode_query_len,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
# Batch may be composed of prefill|decodes, adjust query start
|
||||
# indices to refer to the start of decodes. E.g.
|
||||
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
|
||||
query_start_loc=(self.query_start_loc[self.num_prefills:] -
|
||||
self.query_start_loc[self.num_prefills])
|
||||
if self.query_start_loc is not None else None,
|
||||
seq_start_loc=self.seq_start_loc[self.num_prefills:]
|
||||
if self.seq_start_loc is not None else None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
input_positions=input_positions,
|
||||
head_dim=self.head_dim)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
def advance_step(self,
|
||||
model_input: "ModelInputForVACCWithSamplingMetadata",
|
||||
sampled_token_ids: Optional[torch.Tensor],
|
||||
block_size: int,
|
||||
num_seqs: int,
|
||||
num_queries: int,
|
||||
turn_prefills_into_decodes: bool = False):
|
||||
"""
|
||||
Update metadata in-place to advance one decode step.
|
||||
"""
|
||||
        # When using cudagraph, num_seqs is padded to the next captured
        # batch size, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries.
if num_seqs != num_queries:
|
||||
assert num_seqs > num_queries
|
||||
assert self.use_cuda_graph
|
||||
|
||||
        if turn_prefills_into_decodes:
            # When Multi-Step is enabled with Chunked-Prefill, prefills and
            # decodes are scheduled together. In the first step, all the
            # prefills turn into decodes. This update reflects that
            # conversion.
|
||||
assert self.num_decode_tokens + self.num_prefills == num_seqs
|
||||
self.num_decode_tokens += self.num_prefills
|
||||
self.num_prefills = 0
|
||||
# self.num_prefill_tokens = 0
|
||||
# self.max_prefill_seq_len = 0
|
||||
self.max_query_len = 1
|
||||
|
||||
self.slot_mapping = self.slot_mapping[:num_seqs]
|
||||
else:
|
||||
assert self.seq_lens is not None
|
||||
assert self.max_decode_seq_len == max(self.seq_lens)
|
||||
|
||||
assert self.num_prefills == 0
|
||||
assert self.num_prefill_tokens == 0
|
||||
assert self.num_decode_tokens == num_seqs
|
||||
assert self.slot_mapping.shape == (num_seqs, )
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert len(self.seq_lens) == num_seqs
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert self.seq_lens_tensor.shape == (num_seqs, )
|
||||
# assert self.max_query_len == 1
|
||||
# assert self.max_prefill_seq_len == 0
|
||||
|
||||
assert self.query_start_loc is not None
|
||||
assert self.query_start_loc.shape == (num_queries + 1, )
|
||||
assert self.seq_start_loc is not None
|
||||
assert self.seq_start_loc.shape == (num_seqs + 1, )
|
||||
|
||||
assert self.context_lens_tensor is not None
|
||||
assert self.context_lens_tensor.shape == (num_queries, )
|
||||
|
||||
assert self.block_tables is not None
|
||||
assert self.block_tables.shape[0] == num_seqs
|
||||
|
||||
# Update query lengths. Note that we update only queries and not seqs,
|
||||
# since tensors may be padded due to captured cuda graph batch size
|
||||
for i in range(num_queries):
|
||||
self.seq_lens[i] += 1
|
||||
# self.max_decode_seq_len = None
|
||||
|
||||
ops.advance_step_flashattn(num_seqs=num_seqs,
|
||||
num_queries=num_queries,
|
||||
block_size=block_size,
|
||||
input_tokens=model_input.input_tokens,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
input_positions=model_input.input_positions,
|
||||
seq_lens=self.seq_lens_tensor,
|
||||
slot_mapping=self.slot_mapping,
|
||||
block_tables=self.block_tables)
|
||||
|
||||
|
||||
class VACCMLAMetadataBuilder(AttentionMetadataBuilder[VACCMLAMetadata]):
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForVACCBuilder"):
|
||||
self.chunked_prefill = True
|
||||
if hasattr(input_builder, 'chunked_prefill'):
|
||||
self.chunked_prefill = input_builder.chunked_prefill
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
self.sliding_window = input_builder.sliding_window
|
||||
self.block_size = input_builder.block_size
|
||||
|
||||
def prepare(self):
|
||||
self.slot_mapping: List[int] = []
|
||||
self.prefill_seq_lens: List[int] = []
|
||||
self.context_lens: List[int] = []
|
||||
self.block_tables: List[List[int]] = []
|
||||
self.curr_seq_lens: List[int] = []
|
||||
self.input_positions: List[int] = []
|
||||
self.multimodal_placeholder_maps: Dict[
|
||||
str,
|
||||
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
|
||||
self.num_prefills = 0
|
||||
self.num_prefill_tokens = 0
|
||||
self.num_decode_tokens = 0
|
||||
self.has_prefix_cache_hit = False
|
||||
|
||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||
cuda_graph_pad_size: int, batch_size: int):
|
||||
"""Build attention metadata with on-device tensors.
|
||||
|
||||
Args:
|
||||
seq_lens: The maybe padded sequence lengths of the input sequences.
|
||||
query_lens: The query lengths of the input sequences.
|
||||
cuda_graph_pad_size: The padding size for cuda graph.
|
||||
-1 if cuda graph is not used.
|
||||
batch_size: The maybe padded batch size.
|
||||
"""
|
||||
|
||||
self.input_data = self.input_builder.input_data
|
||||
|
||||
self.slot_mapping=self.input_data.slot_mapping
|
||||
self.context_lens= self.input_data.context_lens
|
||||
if self.input_data.num_prefill_tokens !=0:
|
||||
|
||||
self.block_tables = self.input_data.prefill_block_tables
|
||||
else:
|
||||
self.block_tables= self.input_data.decode_block_tables
|
||||
self.input_positions= self.input_data.input_positions
|
||||
|
||||
self.prefill_seq_lens = seq_lens[0:self.input_data.num_prefills]
|
||||
|
||||
self.num_prefills = self.input_data.num_prefills
|
||||
self.num_prefill_tokens = self.input_data.num_prefill_tokens
|
||||
self.num_decode_tokens = self.input_data.num_decode_tokens
|
||||
|
||||
device = self.runner.device
|
||||
use_captured_graph = cuda_graph_pad_size != -1
|
||||
|
||||
# max_query_len = max(query_lens)
|
||||
# decode_query_lens = query_lens[self.num_prefills:]
|
||||
# if len(decode_query_lens) > 0:
|
||||
# max_decode_query_len = max(decode_query_lens)
|
||||
# else:
|
||||
# max_decode_query_len = 1
|
||||
# max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||
# max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
||||
num_decode_tokens = self.num_decode_tokens
|
||||
query_start_loc = list(accumulate(query_lens, initial=0))
|
||||
seq_start_loc = list(accumulate(seq_lens, initial=0))
|
||||
|
||||
num_seqs = len(seq_lens)
|
||||
if use_captured_graph:
|
||||
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
|
||||
self.block_tables.extend([] * cuda_graph_pad_size)
|
||||
num_decode_tokens = batch_size - self.num_prefill_tokens
|
||||
block_tables = self._get_graph_runner_block_tables(
|
||||
num_seqs, self.block_tables)
|
||||
else:
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
)
|
||||
# assert max_query_len > 0, ("query_lens: {}".format(query_lens))
|
||||
|
||||
assert device is not None
|
||||
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
|
||||
device, self.runner.pin_memory)
|
||||
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
|
||||
self.runner.pin_memory)
|
||||
input_positions = async_tensor_h2d(self.input_positions, torch.int,
|
||||
device, self.runner.pin_memory)
|
||||
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int,
|
||||
device, self.runner.pin_memory)
|
||||
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
|
||||
device,
|
||||
self.runner.pin_memory)
|
||||
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
|
||||
device, self.runner.pin_memory)
|
||||
placeholder_index_maps = {
|
||||
modality: placeholder_map.index_map()
|
||||
for modality, placeholder_map in
|
||||
self.multimodal_placeholder_maps.items()
|
||||
}
|
||||
return VACCMLAMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=True,
|
||||
input_positions=input_positions,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
# max_query_len=max_query_len,
|
||||
# max_decode_query_len=None,
|
||||
max_prefill_seq_len=None,
|
||||
max_decode_seq_len=None,
|
||||
query_start_loc=query_start_loc_tensor,
|
||||
seq_start_loc=seq_start_loc_tensor,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=use_captured_graph,
|
||||
num_kv_splits=4, # TODO(lucas) add heuristic
|
||||
head_dim=self.runner.model_config.get_head_size(),
|
||||
)
|
||||
|
||||
|
||||
class VACCMLAImpl(MLACommonImpl[VACCMLAMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
**kwargs) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **kwargs)
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"VACCMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"VACCMLAImpl")
|
||||
|
||||
def extract_weights(self):
|
||||
weights = {}
|
||||
if hasattr(self, 'W_Q'):
|
||||
weights["W_Q"] = self.W_Q
|
||||
if hasattr(self, 'W_Q_scales'):
|
||||
weights["W_Q_scales"] = self.W_Q_scales
|
||||
if hasattr(self, 'W_QR'):
|
||||
weights['W_QR'] = self.W_QR
|
||||
if hasattr(self, 'W_QR_scales'):
|
||||
weights["W_QR_scales"] = self.W_QR_scales
|
||||
if hasattr(self, 'W_Q_QR'):
|
||||
weights["W_Q_QR"] = self.W_Q_QR
|
||||
if hasattr(self, 'W_Q_QR_scales'):
|
||||
weights["W_Q_QR_scales"] = self.W_Q_QR_scales
|
||||
if hasattr(self, 'W_UK'):
|
||||
weights['W_UK'] = self.W_UK
|
||||
if hasattr(self, 'W_UK_scales'):
|
||||
weights['W_UK_scales'] = self.W_UK_scales
|
||||
if hasattr(self, 'W_Q_UK_scales'):
|
||||
weights['W_Q_UK_scales'] = self.W_Q_UK_scales
|
||||
if hasattr(self, 'W_UV'):
|
||||
weights['W_UV'] = self.W_UV
|
||||
if hasattr(self, 'W_UV_scales'):
|
||||
weights['W_UV_scales'] = self.W_UV_scales
|
||||
if hasattr(self, 'W_UV_O'):
|
||||
weights['W_UV_O'] = self.W_UV_O
|
||||
if hasattr(self, 'W_UV_O_scales'):
|
||||
weights['W_UV_O_scales'] = self.W_UV_O_scales
|
||||
return weights
|
||||
|
||||
def _forward_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: VACCMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert isinstance(attn_metadata, VACCMLAMetadata)
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0]\
|
||||
.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope\
|
||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
|
||||
v = v.contiguous()
|
||||
|
||||
# For MLA the v head dim is smaller than qk head dim so we pad out
|
||||
# v with 0s to match the qk head dim
|
||||
# v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
|
||||
# value=0)
|
||||
# attn_output = torch.vacc.scaled_dot_product_attention(
|
||||
# query=q,
|
||||
# key=k,
|
||||
# value=v_padded,
|
||||
# attn_mask=None,
|
||||
# dropout_p=0,
|
||||
# is_causal=True,
|
||||
# is_train=False,
|
||||
# recompute=False,
|
||||
# flash_attention=True,
|
||||
# sm_scale=self.scale
|
||||
# )
|
||||
|
||||
# attn_output = attn_output\
|
||||
# .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
|
||||
# .reshape(-1, self.num_heads * v.shape[-1])
|
||||
seq_lens = attn_metadata.prefill_metadata.seq_lens
|
||||
if len(seq_lens) == 1:
|
||||
            # VACC supports different head dims for v and qk.
            attn_output = torch.vacc.scaled_dot_product_attention(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=True,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=self.scale
|
||||
)
|
||||
attn_out = attn_output.view(-1, self.num_heads * v.shape[-1])
|
||||
else:
|
||||
attn_outs = []
|
||||
start = 0
|
||||
for seq in seq_lens:
|
||||
end = start + seq
|
||||
attn_out = torch.vacc.scaled_dot_product_attention(
|
||||
query=q[start:end, :],
|
||||
key=k[start:end, :],
|
||||
value=v[start:end, :],
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=True,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=self.scale
|
||||
)
|
||||
start = end
|
||||
attn_outs.append(attn_out)
|
||||
attn_out = torch.cat(attn_outs, dim=0).view(-1, self.num_heads * v.shape[-1])
|
||||
|
||||
return self.o_proj(attn_out)[0]
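
    # Sketch (not part of the original file): splitting the packed prefill
    # batch by seq_lens keeps causal masking per request. E.g. for
    # seq_lens = [3, 5] the loop above runs SDPA on rows 0:3 and 3:8
    # separately and concatenates the results back in token order.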
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: VACCMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Triton MLA not yet supported")
|
||||
|
||||
decode_meta = attn_metadata.decode_metadata
|
||||
assert decode_meta is not None
|
||||
B = q_nope.shape[0]
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
o = torch.zeros(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
|
||||
# Add a head dim of 1
|
||||
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||
# print(f"kv_c_and_k_pe_cache: {kv_c_and_k_pe_cache.shape} ")
|
||||
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
|
||||
|
||||
# Run MQA using paged_attention
|
||||
# o = torch.vacc.paged_attention(
|
||||
# query=q,
|
||||
# key_cache=kv_c_and_k_pe_cache,
|
||||
# value_cache=kv_c_cache,
|
||||
# block_table=decode_meta.block_tables,
|
||||
# seq_len=decode_meta.seq_lens_tensor,
|
||||
# out=o,
|
||||
# sm_scale=self.scale
|
||||
# )
|
||||
|
||||
# Run MQA using spda
|
||||
# t0 = time.time()
|
||||
o = vacc_paged_attention_naive(
|
||||
q,
|
||||
kv_c_and_k_pe_cache,
|
||||
kv_c_cache,
|
||||
block_table = decode_meta.block_tables,
|
||||
# seq_lens = decode_meta.seq_lens_tensor,
|
||||
seq_lens=decode_meta.seq_lens,
|
||||
out = o,
|
||||
sm_scale=self.scale)
|
||||
# print(f'{os.getpid()} paged_atten(seq: {decode_meta.seq_lens}) time: {time.time() - t0}')
|
||||
|
||||
return self._v_up_proj_and_o_proj(o)
|
||||
|
||||
def vacc_paged_attention_naive(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
# seq_lens: torch.Tensor,
|
||||
seq_lens: int,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
sm_scale = -1
|
||||
) -> torch.Tensor:
|
||||
|
||||
    # guarantee batch=1 perf
    if len(seq_lens) == 1:
|
||||
k = key_cache.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[0]]
|
||||
v = value_cache.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[0]]
|
||||
attn_out = torch.vacc.scaled_dot_product_attention(
|
||||
query=query,
|
||||
key=k,
|
||||
value=v,
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=False,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=sm_scale
|
||||
)
|
||||
else:
|
||||
# t0 = time.time()
|
||||
attn_outs = []
|
||||
for i in range(len(seq_lens)):
|
||||
k_slices = key_cache[block_table[i], :, :, :]
|
||||
k = torch.cat([k_slices[i, :, :, :].unsqueeze(1) for i in range(len(block_table[i]))], dim=0)
|
||||
k = k.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[i]]
|
||||
v_slices = value_cache[block_table[i], :, :, :]
|
||||
v = torch.cat([v_slices[i, :, :, :].unsqueeze(1) for i in range(len(block_table[i]))], dim=0)
|
||||
v = v.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[i]]
|
||||
|
||||
attn_out = torch.vacc.scaled_dot_product_attention(
|
||||
query=query[i:i+1,:,:],
|
||||
key=k,
|
||||
value=v,
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=False,
|
||||
is_train=False,
|
||||
recompute=False,
|
||||
flash_attention=False,
|
||||
sm_scale=sm_scale
|
||||
)
|
||||
attn_outs.append(attn_out)
|
||||
|
||||
attn_out = torch.cat(attn_outs, dim=0)
|
||||
# print(f'{os.getpid()} call spda(seq: {seq_lens}) time: {time.time() - t0}')
|
||||
return attn_out
|
||||
|
||||
# MLA single op impl
|
||||
def vacc_paged_attention_naive_singleop(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
seq_lens,
|
||||
block_table = None,
|
||||
out: torch.Tensor = None,
|
||||
sm_scale = -1
|
||||
) -> torch.Tensor:
|
||||
k = key_cache.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens]
|
||||
v = value_cache.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens].squeeze(1)
|
||||
pe_cache = k[..., 512:].squeeze(1)
|
||||
print(f'q:{query[..., :512].shape} v:{v.shape} pe_cache:{pe_cache.shape}')
|
||||
q_nope_kv_c = torch.einsum("shc,tc->sht", query[..., :512], v)
|
||||
q_pe_k_pe = torch.einsum("shr,tr->sht", query[..., 512:], pe_cache)
|
||||
scores = (q_nope_kv_c + q_pe_k_pe) * sm_scale
|
||||
scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(query)
|
||||
o = torch.einsum("sht,tc->shc", scores, v)
|
||||
return o
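
# Sketch (not part of the original file): in the single-op path above,
# "shc,tc->sht" scores every (token s, head h) query against all t cached
# latent vectors. For a query of shape (S, H, 576) split into 512 nope dims
# and 64 rope dims and a cache of T tokens, scores has shape (S, H, T) and
# the returned o has shape (S, H, 512).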
|
||||
0
vllm_vacc/vllm/attention/ops/__init__.py
Normal file
0
vllm_vacc/vllm/attention/ops/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
160
vllm_vacc/vllm/attention/ops/vacc_paged_attn.py
Normal file
160
vllm_vacc/vllm/attention/ops/vacc_paged_attn.py
Normal file
@@ -0,0 +1,160 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||
|
||||
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
|
||||
_PARTITION_SIZE = 512
|
||||
|
||||
@dataclass
|
||||
class PagedAttentionMetadata:
|
||||
"""Metadata for PagedAttention."""
|
||||
# (batch_size,). The length of sequences (entire tokens seen so far) per
|
||||
# sequence.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
||||
# in the kv cache. Each block can contain up to block_size tokens.
|
||||
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
||||
# captured.
|
||||
block_tables: Optional[torch.Tensor]
|
||||
|
||||
class VaccPagedAttention:
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [32, 64, 80, 96, 112, 120, 128, 192, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size * num_kv_heads * head_size)
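
    # Illustrative example (not from the original source): with
    # num_blocks=1024, block_size=16, num_kv_heads=8, head_size=128 the
    # cache shape is (2, 1024, 16 * 8 * 128) = (2, 1024, 16384); index 0
    # holds keys and index 1 holds values, which split_kv_cache below
    # reshapes back to (num_blocks, block_size, num_kv_heads, head_size).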
|
||||
|
||||
@staticmethod
|
||||
def split_kv_cache(
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# x = 16 // kv_cache.element_size()
|
||||
num_blocks = kv_cache.shape[1]
|
||||
|
||||
key_cache = kv_cache[0]
|
||||
key_cache = key_cache.view(num_blocks, -1,num_kv_heads, head_size)
|
||||
value_cache = kv_cache[1]
|
||||
value_cache = value_cache.view(num_blocks, -1, num_kv_heads, head_size)
|
||||
return key_cache, value_cache
|
||||
|
||||
@staticmethod
|
||||
def write_to_paged_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
) -> None:
|
||||
# list_from_tensor = slot_mapping.tolist()
|
||||
torch.vacc.reshape_and_cache_attention(key,key_cache,slot_mapping)
|
||||
torch.vacc.reshape_and_cache_attention(value,value_cache,slot_mapping)
|
||||
|
||||
@staticmethod
|
||||
def forward_decode(
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
max_seq_len: int,
|
||||
kv_cache_dtype: str,
|
||||
num_kv_heads: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
tp_rank: int = 0,
|
||||
blocksparse_local_blocks: int = 0,
|
||||
blocksparse_vert_stride: int = 0,
|
||||
blocksparse_block_size: int = 64,
|
||||
blocksparse_head_sliding_step: int = 0,
|
||||
) -> torch.Tensor:
|
||||
torch.vacc.paged_attention(query,key_cache,value_cache,block_tables,seq_lens,-1,output)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def forward_prefix(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
seq_lens_tensor: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
max_query_len: int,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
sliding_window: Optional[int],
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
) -> torch.Tensor:
|
||||
output = torch.empty_like(query)
|
||||
context_attention_fwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables,
|
||||
# query_start_loc is (batch_size + 1,)
|
||||
query_start_loc[:-1],
|
||||
seq_lens_tensor,
|
||||
context_lens,
|
||||
max_query_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
src_key_cache = src_kv_cache[0]
|
||||
dst_key_cache = dst_kv_cache[0]
|
||||
torch.vacc.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
|
||||
|
||||
src_value_cache = src_kv_cache[1]
|
||||
dst_value_cache = dst_kv_cache[1]
|
||||
torch.vacc.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
key_caches = [kv_cache[0] for kv_cache in kv_caches]
|
||||
value_caches = [kv_cache[1] for kv_cache in kv_caches]
|
||||
torch.vacc.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
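# --- Illustrative sketch (not part of the committed file) ---
# VaccPagedAttention stores K and V flattened per block:
#   (2, num_blocks, block_size * num_kv_heads * head_size)
# and split_kv_cache() views each half back to
#   (num_blocks, block_size, num_kv_heads, head_size).
# A minimal round-trip with made-up sizes; no torch.vacc kernels are needed here.
import torch

num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 128
kv_cache = torch.zeros((2, num_blocks, block_size * num_kv_heads * head_size))

key_cache = kv_cache[0].view(num_blocks, -1, num_kv_heads, head_size)
value_cache = kv_cache[1].view(num_blocks, -1, num_kv_heads, head_size)
assert key_cache.shape == (num_blocks, block_size, num_kv_heads, head_size)
assert value_cache.shape == key_cache.shape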
154
vllm_vacc/vllm/config.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import ast
|
||||
import copy
|
||||
import enum
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import textwrap
|
||||
import warnings
|
||||
from collections import Counter
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
|
||||
replace)
|
||||
from importlib.util import find_spec
|
||||
from pathlib import Path
|
||||
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
|
||||
Optional, Protocol, TypeVar, Union, get_args)
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
|
||||
QuantizationMethods,
|
||||
get_quantization_config)
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.config.model import _STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def ModelConfig___verify_quantization(self) -> None:
|
||||
supported_quantization = QUANTIZATION_METHODS
|
||||
optimized_quantization_methods = [
|
||||
"fp8", "modelopt", "gptq_marlin_24", "gptq_marlin",
|
||||
"awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
|
||||
"quark", "modelopt_fp4", "bitblas"#, "gptq_bitblas"
|
||||
]
|
||||
if self.quantization is not None:
|
||||
self.quantization = self.quantization.lower()
|
||||
|
||||
# Parse quantization method from the HF model config, if available.
|
||||
quant_cfg = self._parse_quant_hf_config(self.hf_config)
|
||||
if quant_cfg is None and (text_config := getattr(
|
||||
self.hf_config, "text_config", None)):
|
||||
# Check the text config as well for multi-modal models.
|
||||
quant_cfg = self._parse_quant_hf_config(text_config)
|
||||
|
||||
if quant_cfg is not None:
|
||||
quant_method = quant_cfg.get("quant_method", "").lower()
|
||||
quant_method = quant_method.replace("compressed_tensors",
|
||||
"compressed-tensors")
|
||||
quant_cfg["quant_method"] = quant_method
|
||||
|
||||
# Quantization methods which are overrides (i.e. they have a
|
||||
# `override_quantization_method` method) must be checked in order
|
||||
# of preference (this is particularly important for GPTQ).
|
||||
overrides = [
|
||||
# "marlin",
|
||||
"bitblas",
|
||||
"gptq_marlin_24",
|
||||
"gptq_marlin",
|
||||
# "gptq_bitblas",
|
||||
"awq_marlin",
|
||||
"ipex",
|
||||
"moe_wna16",
|
||||
"modelopt",
|
||||
"modelopt_fp4",
|
||||
"petit_nvfp4",
|
||||
]
|
||||
quantization_methods = [
|
||||
q for q in supported_quantization if q not in overrides
|
||||
]
|
||||
# Any custom overrides will be in quantization_methods so we place
|
||||
# them at the start of the list so custom overrides have preference
|
||||
# over the built in ones.
|
||||
quantization_methods = quantization_methods + overrides
|
||||
|
||||
# Detect which checkpoint it is
|
||||
for name in quantization_methods:
|
||||
method = get_quantization_config(name)
|
||||
quantization_override = method.override_quantization_method(
|
||||
quant_cfg, self.quantization)
|
||||
if quantization_override is not None:
|
||||
# Raise error if the override is not custom (custom would
|
||||
# be in QUANTIZATION_METHODS but not QuantizationMethods)
|
||||
# and hasn't been added to the overrides list.
|
||||
if (name in get_args(QuantizationMethods)
|
||||
and name not in overrides):
|
||||
raise ValueError(
|
||||
f"Quantization method {name} is an override but "
|
||||
"is has not been added to the `overrides` list "
|
||||
"above. This is necessary to ensure that the "
|
||||
"overrides are checked in order of preference.")
|
||||
quant_method = quantization_override
|
||||
self.quantization = quantization_override
|
||||
break
|
||||
|
||||
# Verify quantization configurations.
|
||||
if self.quantization is None:
|
||||
self.quantization = quant_method
|
||||
elif self.quantization != quant_method:
|
||||
raise ValueError(
|
||||
"Quantization method specified in the model config "
|
||||
f"({quant_method}) does not match the quantization "
|
||||
f"method specified in the `quantization` argument "
|
||||
f"({self.quantization}).")
|
||||
|
||||
if self.quantization is not None:
|
||||
if self.quantization not in supported_quantization:
|
||||
raise ValueError(
|
||||
f"Unknown quantization method: {self.quantization}. Must "
|
||||
f"be one of {supported_quantization}.")
|
||||
from vllm.platforms import current_platform
|
||||
current_platform.verify_quantization(self.quantization)
|
||||
if self.quantization not in optimized_quantization_methods:
|
||||
logger.warning(
|
||||
"%s quantization is not fully "
|
||||
"optimized yet. The speed can be slower than "
|
||||
"non-quantized models.", self.quantization)
|
||||
|
||||
|
||||
def _get_head_dtype(config: PretrainedConfig, dtype: torch.dtype,
|
||||
runner_type: str) -> torch.dtype:
|
||||
head_dtype: Optional[Union[str,
|
||||
torch.dtype]] = getattr(config, "head_dtype",
|
||||
None)
|
||||
|
||||
if head_dtype == "model":
|
||||
return dtype
|
||||
elif isinstance(head_dtype, str):
|
||||
head_dtype = head_dtype.lower()
|
||||
if head_dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
|
||||
raise ValueError(f"Unknown dtype: {head_dtype!r}")
|
||||
return _STR_DTYPE_TO_TORCH_DTYPE[head_dtype]
|
||||
elif isinstance(head_dtype, torch.dtype):
|
||||
return head_dtype
|
||||
elif head_dtype is None:
|
||||
if torch.float32 not in current_platform.supported_dtypes:
|
||||
return dtype
|
||||
if runner_type == "pooling":
|
||||
return torch.float16
|
||||
return dtype
|
||||
else:
|
||||
raise ValueError(f"Unknown dtype: {head_dtype}")
|
||||
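# --- Illustrative sketch (not part of the committed file) ---
# How _get_head_dtype() resolves the (assumed) "head_dtype" attribute on the
# HF config: "model" keeps the model dtype, a string is looked up in
# _STR_DTYPE_TO_TORCH_DTYPE, a torch.dtype passes through, and None falls back
# to the model dtype (float16 for pooling runners when float32 is supported).
# Simplified stand-in with a reduced dtype table:
import torch

def resolve_head_dtype_demo(head_dtype, model_dtype, runner_type):
    table = {"float16": torch.float16, "bfloat16": torch.bfloat16,
             "float32": torch.float32}
    if head_dtype == "model":
        return model_dtype
    if isinstance(head_dtype, str):
        return table[head_dtype.lower()]
    if isinstance(head_dtype, torch.dtype):
        return head_dtype
    # head_dtype is None (the float32-support check is omitted in this demo)
    return torch.float16 if runner_type == "pooling" else model_dtype

assert resolve_head_dtype_demo("model", torch.bfloat16, "generate") == torch.bfloat16
assert resolve_head_dtype_demo(None, torch.bfloat16, "pooling") == torch.float16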
52
vllm_vacc/vllm/config_manager.py
Normal file
@@ -0,0 +1,52 @@
#####################################################
## 1. use for Memory-Recycler
##    .model_infos
##    ['deepseek_mtp',]
##
## 2. waiting...
######################################################

import os


class ConfigManager():

    def __init__(self):
        self._config_name = ".model_infos"

    def update_model_infos(self, model_infos: str):
        from pathlib import Path
        workspace_path = Path.cwd()

        bootinfo_config = f'{workspace_path}/{self._config_name}'
        try:
            with open(bootinfo_config, 'w') as w:
                w.write(model_infos)
        except Exception as e:
            print("[WARN] write model_infos fail, caused by ", e)
            raise

    def get_model_infos(self):
        from pathlib import Path
        workspace_path = Path.cwd()

        bootinfo_config = f'{workspace_path}/{self._config_name}'
        bootinfo_inited = os.path.exists(bootinfo_config)

        runner_model_infos = "default"
        if bootinfo_inited:
            try:
                with open(bootinfo_config) as w:
                    runner_model_infos = w.readline()
            except Exception as e:
                print("[WARN] model_infos load fail ", e)

        return runner_model_infos


config_manager = None


def vllm_vacc_config_manager():
    global config_manager

    if config_manager is None:
        config_manager = ConfigManager()
    return config_manager
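# --- Illustrative usage (not part of the committed file) ---
# The module-level singleton writes/reads ".model_infos" in the current
# working directory. The import path below assumes the module is importable
# as vllm_vacc.vllm.config_manager (its path in this commit); "deepseek_mtp"
# is just an example value.
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager

cm = vllm_vacc_config_manager()
cm.update_model_infos("deepseek_mtp")
print(cm.get_model_infos())   # -> "deepseek_mtp", or "default" if the file is missing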
0
vllm_vacc/vllm/core/__init__.py
Normal file
BIN
vllm_vacc/vllm/core/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/core/__pycache__/block_manager.cpython-312.pyc
Normal file
Binary file not shown.
0
vllm_vacc/vllm/core/block/__init__.py
Normal file
BIN
vllm_vacc/vllm/core/block/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/core/block/__pycache__/common.cpython-312.pyc
Normal file
Binary file not shown.
407
vllm_vacc/vllm/core/block/block_table.py
Normal file
@@ -0,0 +1,407 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import math
|
||||
from typing import List, Optional
|
||||
|
||||
from vllm.core.block.common import BlockList
|
||||
from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator
|
||||
from vllm.utils import Device, cdiv, chunk_list
|
||||
|
||||
|
||||
class BlockTable:
|
||||
"""A class to manage blocks for a specific sequence.
|
||||
|
||||
The BlockTable maps a sequence of tokens to a list of blocks, where each
|
||||
block represents a contiguous memory allocation for a portion of the
|
||||
sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is
|
||||
responsible for allocating and freeing memory for the blocks.
|
||||
|
||||
Args:
|
||||
block_size (int): The maximum number of tokens that can be stored in a
|
||||
single block.
|
||||
block_allocator (DeviceAwareBlockAllocator): The block allocator used to
|
||||
manage memory for the blocks.
|
||||
_blocks (Optional[List[Block]], optional): An optional list of existing
|
||||
blocks to initialize the BlockTable with. If not provided, an empty
|
||||
BlockTable is created.
|
||||
max_block_sliding_window (Optional[int], optional): The number of
|
||||
blocks to keep around for each sequence. If None, all blocks
|
||||
are kept (eg., when sliding window is not used).
|
||||
It should at least fit the sliding window size of the model.
|
||||
|
||||
Attributes:
|
||||
_block_size (int): The maximum number of tokens that can be stored in a
|
||||
single block.
|
||||
_allocator (DeviceAwareBlockAllocator): The block allocator used to
|
||||
manage memory for the blocks.
|
||||
_blocks (Optional[List[Block]]): The list of blocks managed by this
|
||||
BlockTable.
|
||||
_num_full_slots (int): The number of tokens currently stored in the
|
||||
blocks.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_size: int,
|
||||
block_allocator: DeviceAwareBlockAllocator,
|
||||
_blocks: Optional[List[Block]] = None,
|
||||
max_block_sliding_window: Optional[int] = None,
|
||||
):
|
||||
self._block_size = block_size
|
||||
self._allocator = block_allocator
|
||||
if _blocks is None:
|
||||
_blocks = []
|
||||
self._blocks: BlockList = BlockList(_blocks)
|
||||
|
||||
self._max_block_sliding_window = max_block_sliding_window
|
||||
self._num_full_slots = self._get_num_token_ids()
|
||||
|
||||
@staticmethod
|
||||
def get_num_required_blocks(token_ids: List[int],
|
||||
block_size: int,
|
||||
num_lookahead_slots: int = 0) -> int:
|
||||
"""Calculates the minimum number of blocks required to store a given
|
||||
sequence of token IDs along with any look-ahead slots that may be
|
||||
required (like in multi-step + chunked-prefill).
|
||||
|
||||
This assumes worst-case scenario, where every block requires a new
|
||||
allocation (e.g. ignoring prefix caching).
|
||||
|
||||
Args:
|
||||
token_ids (List[int]): The sequence of token IDs to be stored.
|
||||
block_size (int): The maximum number of tokens that can be stored in
|
||||
a single block.
|
||||
num_lookahead_slots (int): look-ahead slots that the sequence may
|
||||
require.
|
||||
|
||||
Returns:
|
||||
int: The minimum number of blocks required to store the given
|
||||
sequence of token IDs along with any required look-ahead slots.
|
||||
"""
|
||||
return cdiv(len(token_ids) + num_lookahead_slots, block_size)
|
||||
|
||||
def allocate(self,
|
||||
token_ids: List[int],
|
||||
device: Device = Device.GPU,
|
||||
extra_hash: Optional[int] = None,
|
||||
seq_id: Optional[int] = None) -> None:
|
||||
"""Allocates memory blocks for storing the given sequence of token IDs.
|
||||
|
||||
This method allocates the required number of blocks to store the given
|
||||
sequence of token IDs.
|
||||
|
||||
Args:
|
||||
token_ids (List[int]): The sequence of token IDs to be stored.
|
||||
device (Device, optional): The device on which the blocks should be
|
||||
allocated. Defaults to Device.GPU.
|
||||
extra_hash (Optional[int]): The hash value of additional
|
||||
factors, such as adapters, that influence the block hash
|
||||
in the prefixcaching block.
|
||||
"""
|
||||
assert not self._is_allocated
|
||||
assert token_ids
|
||||
blocks = self._allocate_blocks_for_token_ids(prev_block=None,
|
||||
token_ids=token_ids,
|
||||
device=device,
|
||||
extra_hash=extra_hash,
|
||||
seq_id=seq_id)
|
||||
self.update(blocks)
|
||||
self._num_full_slots = len(token_ids)
|
||||
|
||||
def update(self, blocks: List[Block]) -> None:
|
||||
"""Resets the table to the newly provided blocks
|
||||
(with their corresponding block ids)
|
||||
"""
|
||||
self._blocks.update(blocks)
|
||||
|
||||
def append_token_ids(self,
|
||||
token_ids: List[int],
|
||||
num_lookahead_slots: int = 0,
|
||||
num_computed_slots: Optional[int] = None,
|
||||
extra_hash: Optional[int] = None,
|
||||
seq_id: Optional[int] = None) -> None:
|
||||
"""Appends a sequence of token IDs to the existing blocks in the
|
||||
BlockTable.
|
||||
|
||||
This method appends the given sequence of token IDs to the existing
|
||||
blocks in the BlockTable. If there is not enough space in the existing
|
||||
blocks, new blocks are allocated using the `ensure_num_empty_slots`
|
||||
method to accommodate the additional tokens.
|
||||
|
||||
The token IDs are divided into chunks of size `block_size` (except for
|
||||
the first chunk, which may be smaller), and each chunk is appended to a
|
||||
separate block.
|
||||
|
||||
Args:
|
||||
token_ids (List[int]): The sequence of token IDs to be appended.
|
||||
num_computed_slots (Optional[int]): The number of KV cache slots
|
||||
that are already filled (computed).
|
||||
When sliding window is enabled, this is used to compute how many
|
||||
blocks to drop at the front of the sequence.
|
||||
Without sliding window, None can be passed.
|
||||
Without chunked prefill, it should be the same as
|
||||
_num_full_slots.
|
||||
extra_hash (Optional[int]): The hash value of additional
|
||||
factors such as adapters that influence the block, apart
|
||||
from the token_ids.
|
||||
"""
|
||||
assert self._is_allocated, "no blocks have been allocated"
|
||||
assert len(self._blocks) > 0
|
||||
|
||||
# Drop blocks that are no longer needed due to sliding window
|
||||
if self._max_block_sliding_window is not None:
|
||||
null_block = self._allocator.allocate_or_get_null_block()
|
||||
assert num_computed_slots is not None
|
||||
end_block_idx = (num_computed_slots //
|
||||
self._block_size) - self._max_block_sliding_window
|
||||
for idx in range(0, end_block_idx):
|
||||
b = self._blocks[idx]
|
||||
if b is not null_block:
|
||||
self._allocator.free(b)
|
||||
self._blocks[idx] = null_block
|
||||
|
||||
# Ensure there are enough empty slots for the new tokens plus
|
||||
# lookahead slots
|
||||
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
|
||||
num_lookahead_slots,
|
||||
extra_hash=extra_hash,
|
||||
seq_id=seq_id)
|
||||
|
||||
# Update the blocks with the new tokens
|
||||
first_block_idx = self._num_full_slots // self._block_size
|
||||
token_blocks = self._chunk_token_blocks_for_append(token_ids)
|
||||
|
||||
for i, token_block in enumerate(token_blocks):
|
||||
self._blocks.append_token_ids(first_block_idx + i, token_block, seq_id=seq_id)
|
||||
|
||||
self._num_full_slots += len(token_ids)
|
||||
|
||||
def ensure_num_empty_slots(self,
|
||||
num_empty_slots: int,
|
||||
extra_hash: Optional[int] = None,
|
||||
seq_id: Optional[int] = None) -> None:
|
||||
"""Ensures that the BlockTable has at least the specified number of
|
||||
empty slots available.
|
||||
|
||||
This method checks if the BlockTable has enough empty slots (i.e.,
|
||||
available space) to accommodate the requested number of tokens. If not,
|
||||
it allocates additional blocks on the GPU to ensure that the required
|
||||
number of empty slots is available.
|
||||
|
||||
Args:
|
||||
num_empty_slots (int): The minimum number of empty slots required.
|
||||
extra_hash (Optional[int]): The hash value of additional
|
||||
factors such as adapters that influence the block, apart
|
||||
from the token_ids.
|
||||
"""
|
||||
# Currently the block table only supports
|
||||
# appending tokens to GPU blocks.
|
||||
device = Device.GPU
|
||||
assert self._is_allocated
|
||||
|
||||
if self._num_empty_slots >= num_empty_slots:
|
||||
return
|
||||
|
||||
slots_to_allocate = num_empty_slots - self._num_empty_slots
|
||||
blocks_to_allocate = cdiv(slots_to_allocate, self._block_size)
|
||||
|
||||
for _ in range(blocks_to_allocate):
|
||||
assert len(self._blocks) > 0
|
||||
self._blocks.append(
|
||||
self._allocator.allocate_mutable_block(
|
||||
prev_block=self._blocks[-1],
|
||||
device=device,
|
||||
extra_hash=extra_hash,
|
||||
seq_id=seq_id))
|
||||
|
||||
def fork(self) -> "BlockTable":
|
||||
"""Creates a new BlockTable instance with a copy of the blocks from the
|
||||
current instance.
|
||||
|
||||
This method creates a new BlockTable instance with the same block size,
|
||||
block allocator, and a copy of the blocks from the current instance. The
|
||||
new BlockTable has its own independent set of blocks, but shares the
|
||||
same underlying memory allocation with the original BlockTable.
|
||||
|
||||
Returns:
|
||||
BlockTable: A new BlockTable instance with a copy of the blocks from
|
||||
the current instance.
|
||||
"""
|
||||
assert self._is_allocated
|
||||
assert len(self._blocks) > 0
|
||||
forked_blocks = self._allocator.fork(self._blocks[-1])
|
||||
return BlockTable(
|
||||
block_size=self._block_size,
|
||||
block_allocator=self._allocator,
|
||||
_blocks=forked_blocks,
|
||||
max_block_sliding_window=self._max_block_sliding_window,
|
||||
)
|
||||
|
||||
def free(self, seq_id: Optional[int] = None) -> None:
|
||||
"""Frees the memory occupied by the blocks in the BlockTable.
|
||||
|
||||
This method iterates over all the blocks in the `_blocks` list and calls
|
||||
the `free` method of the `_allocator` object to release the memory
|
||||
occupied by each block. After freeing all the blocks, the `_blocks` list
|
||||
is reset.
|
||||
"""
|
||||
self.blocks.reverse()
|
||||
for block in self.blocks:
|
||||
self._allocator.free(block, seq_id=seq_id)
|
||||
self._blocks.reset()
|
||||
|
||||
@property
|
||||
def physical_block_ids(self) -> List[int]:
|
||||
"""Returns a list of physical block indices for the blocks in the
|
||||
BlockTable.
|
||||
|
||||
This property returns a list of integers, where each integer represents
|
||||
the physical block index of a corresponding block in the `_blocks` list.
|
||||
The physical block index is a unique identifier for the memory location
|
||||
occupied by the block.
|
||||
|
||||
Returns:
|
||||
List[int]: A list of physical block indices for the blocks in the
|
||||
BlockTable.
|
||||
"""
|
||||
return self._blocks.ids()
|
||||
|
||||
def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
|
||||
"""Get the number of "unseen" tokens in the sequence.
|
||||
|
||||
Unseen tokens are tokens in the sequence corresponding to this block
|
||||
table, but are not yet appended to this block table.
|
||||
|
||||
Args:
|
||||
sequence_token_ids (List[int]): The list of token ids in the
|
||||
sequence.
|
||||
|
||||
Returns:
|
||||
List[int]: The postfix of sequence_token_ids that has not yet been
|
||||
appended to the block table.
|
||||
"""
|
||||
|
||||
# Since the block table is append-only, the unseen token ids are the
|
||||
# ones after the appended ones.
|
||||
return sequence_token_ids[self.num_full_slots:]
|
||||
|
||||
def _allocate_blocks_for_token_ids(
|
||||
self,
|
||||
prev_block: Optional[Block],
|
||||
token_ids: List[int],
|
||||
device: Device,
|
||||
extra_hash: Optional[int] = None,
|
||||
seq_id: Optional[int] = None) -> List[Block]:
|
||||
blocks: List[Block] = []
|
||||
|
||||
block_token_ids = []
|
||||
tail_token_ids = []
|
||||
for cur_token_ids in chunk_list(token_ids, self._block_size):
|
||||
if len(cur_token_ids) == self._block_size:
|
||||
block_token_ids.append(cur_token_ids)
|
||||
else:
|
||||
tail_token_ids.append(cur_token_ids)
|
||||
|
||||
if block_token_ids:
|
||||
blocks.extend(
|
||||
self._allocator.allocate_immutable_blocks(
|
||||
prev_block,
|
||||
block_token_ids=block_token_ids,
|
||||
device=device,
|
||||
extra_hash=extra_hash,
|
||||
seq_id=seq_id))
|
||||
prev_block = blocks[-1]
|
||||
|
||||
if tail_token_ids:
|
||||
assert len(tail_token_ids) == 1
|
||||
cur_token_ids = tail_token_ids[0]
|
||||
|
||||
block = self._allocator.allocate_mutable_block(
|
||||
prev_block=prev_block, device=device, extra_hash=extra_hash, seq_id=seq_id)
|
||||
block.append_token_ids(cur_token_ids, seq_id)
|
||||
|
||||
blocks.append(block)
|
||||
|
||||
return blocks
|
||||
|
||||
def _get_all_token_ids(self) -> List[int]:
|
||||
# NOTE: This function is O(seq_len); use sparingly.
|
||||
token_ids: List[int] = []
|
||||
|
||||
if not self._is_allocated:
|
||||
return token_ids
|
||||
|
||||
for block in self.blocks:
|
||||
token_ids.extend(block.token_ids)
|
||||
|
||||
return token_ids
|
||||
|
||||
def _get_num_token_ids(self) -> int:
|
||||
res = 0
|
||||
for block in self.blocks:
|
||||
res += len(block.token_ids)
|
||||
|
||||
return res
|
||||
|
||||
@property
|
||||
def _is_allocated(self) -> bool:
|
||||
return len(self._blocks) > 0
|
||||
|
||||
@property
|
||||
def blocks(self) -> List[Block]:
|
||||
return self._blocks.list()
|
||||
|
||||
@property
|
||||
def _num_empty_slots(self) -> int:
|
||||
assert self._is_allocated
|
||||
return len(self._blocks) * self._block_size - self._num_full_slots
|
||||
|
||||
@property
|
||||
def num_full_slots(self) -> int:
|
||||
"""Returns the total number of tokens currently stored in the
|
||||
BlockTable.
|
||||
|
||||
Returns:
|
||||
int: The total number of tokens currently stored in the BlockTable.
|
||||
"""
|
||||
return self._num_full_slots
|
||||
|
||||
def get_num_blocks_touched_by_append_slots(
|
||||
self, token_ids: List[int], num_lookahead_slots: int) -> int:
|
||||
"""Determine how many blocks will be "touched" by appending the token
|
||||
ids.
|
||||
|
||||
This is required for the scheduler to determine whether a sequence can
|
||||
continue generation, or if it must be preempted.
|
||||
"""
|
||||
# Math below is equivalent to:
|
||||
# all_token_ids = token_ids + [-1] * num_lookahead_slots
|
||||
# token_blocks = self._chunk_token_blocks_for_append(all_token_ids)
|
||||
# return len(token_blocks)
|
||||
|
||||
num_token_ids = len(token_ids) + num_lookahead_slots
|
||||
first_chunk_size = self._block_size - (self._num_full_slots %
|
||||
self._block_size)
|
||||
num_token_blocks = (1 + math.ceil(
|
||||
(num_token_ids - first_chunk_size) / self._block_size))
|
||||
return num_token_blocks
|
||||
|
||||
def _chunk_token_blocks_for_append(
|
||||
self, token_ids: List[int]) -> List[List[int]]:
|
||||
"""Split the token ids into block-sized chunks so they can be easily
|
||||
appended to blocks. The first such "token block" may have less token ids
|
||||
than the block size, since the last allocated block may be partially
|
||||
full.
|
||||
|
||||
If no token ids are provided, then no chunks are returned.
|
||||
"""
|
||||
|
||||
if not token_ids:
|
||||
return []
|
||||
|
||||
first_chunk_size = self._block_size - (self._num_full_slots %
|
||||
self._block_size)
|
||||
token_blocks = [token_ids[:first_chunk_size]]
|
||||
token_blocks.extend(
|
||||
chunk_list(token_ids[first_chunk_size:], self._block_size))
|
||||
return token_blocks
|
||||
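# --- Illustrative sketch (not part of the committed file) ---
# _chunk_token_blocks_for_append() splits new tokens so that the first chunk
# tops up the partially-filled last block and the rest are block_size chunks.
# Pure-Python restatement of that math with made-up numbers:
def chunk_for_append(token_ids, block_size, num_full_slots):
    if not token_ids:
        return []
    first = block_size - (num_full_slots % block_size)
    chunks = [token_ids[:first]]
    rest = token_ids[first:]
    chunks += [rest[i:i + block_size] for i in range(0, len(rest), block_size)]
    return chunks

# 5 tokens already stored, block_size 4 -> the first chunk has 3 slots left
assert chunk_for_append(list(range(10)), 4, 5) == [[0, 1, 2], [3, 4, 5, 6], [7, 8, 9]]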
21
vllm_vacc/vllm/core/block/common.py
Normal file
@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0

from collections import deque
from dataclasses import dataclass
from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple

from vllm.core.block.interfaces import Block, BlockAllocator

BlockId = int
RefCount = int


class BlockList:

    def append_token_ids(self, block_index: int, token_ids: List[int],
                         seq_id: Optional[int] = None) -> None:
        block = self._blocks[block_index]
        prev_block_id = block.block_id

        block.append_token_ids(token_ids, seq_id=seq_id)

        # CoW or promotion may update the internal block_id
        if prev_block_id != block.block_id:
            self._update_block_id(block_index, block.block_id)
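# --- Illustrative sketch (not part of the committed file) ---
# Why append_token_ids() re-checks block.block_id: an append may trigger
# copy-on-write (or promotion), giving the block a new physical id, so the
# cached id list has to be refreshed. Toy stand-in for that bookkeeping:
class ToyBlock:
    def __init__(self, block_id):
        self.block_id = block_id

    def append_token_ids(self, token_ids, seq_id=None):
        self.block_id += 100   # pretend CoW moved the block

block = ToyBlock(block_id=7)
prev_id = block.block_id
block.append_token_ids([1, 2, 3])
if prev_id != block.block_id:
    print(f"update cached id: {prev_id} -> {block.block_id}")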
373
vllm_vacc/vllm/core/block/cpu_gpu_block_allocator.py
Normal file
@@ -0,0 +1,373 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Dict, FrozenSet, List, Optional, Tuple
|
||||
|
||||
from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId,
|
||||
DeviceAwareBlockAllocator)
|
||||
# from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
|
||||
from vllm.core.block.naive_block import NaiveBlock
|
||||
from vllm_vacc.vllm.core.block.naive_block import NaiveBlockAllocator
|
||||
from vllm_vacc.vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import Device
|
||||
|
||||
from vllm.core.block.cpu_gpu_block_allocator import NullBlock
|
||||
|
||||
class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
|
||||
"""A block allocator that can allocate blocks on both CPU and GPU memory.
|
||||
|
||||
This class implements the `DeviceAwareBlockAllocator` interface and provides
|
||||
functionality for allocating and managing blocks of memory on both CPU and
|
||||
GPU devices.
|
||||
|
||||
The `CpuGpuBlockAllocator` maintains separate memory pools for CPU and GPU
|
||||
blocks, and allows for allocation, deallocation, forking, and swapping of
|
||||
blocks across these memory pools.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
allocator_type: str,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
block_size: int,
|
||||
) -> DeviceAwareBlockAllocator:
|
||||
"""Creates a CpuGpuBlockAllocator instance with the specified
|
||||
configuration.
|
||||
|
||||
This static method creates and returns a CpuGpuBlockAllocator instance
|
||||
based on the provided parameters. It initializes the CPU and GPU block
|
||||
allocators with the specified number of blocks, block size, and
|
||||
allocator type.
|
||||
|
||||
Args:
|
||||
allocator_type (str): The type of block allocator to use for CPU
|
||||
and GPU blocks. Currently supported values are "naive" and
|
||||
"prefix_caching".
|
||||
num_gpu_blocks (int): The number of blocks to allocate for GPU
|
||||
memory.
|
||||
num_cpu_blocks (int): The number of blocks to allocate for CPU
|
||||
memory.
|
||||
block_size (int): The size of each block in number of tokens.
|
||||
|
||||
Returns:
|
||||
DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with the
|
||||
specified configuration.
|
||||
|
||||
Notes:
|
||||
- The block IDs are assigned contiguously, with GPU block IDs coming
|
||||
before CPU block IDs.
|
||||
"""
|
||||
# For HPU, block id 0 is used only for padding
|
||||
reserved_blocks = 1 if current_platform.is_hpu() else 0
|
||||
block_ids = list(
|
||||
range(reserved_blocks, num_gpu_blocks + num_cpu_blocks))
|
||||
num_gpu_blocks -= reserved_blocks
|
||||
gpu_block_ids = block_ids[:num_gpu_blocks]
|
||||
cpu_block_ids = block_ids[num_gpu_blocks:]
|
||||
|
||||
if allocator_type == "naive":
|
||||
gpu_allocator: BlockAllocator = NaiveBlockAllocator(
|
||||
create_block=NaiveBlock, # type: ignore
|
||||
num_blocks=num_gpu_blocks,
|
||||
block_size=block_size,
|
||||
block_ids=gpu_block_ids,
|
||||
)
|
||||
|
||||
cpu_allocator: BlockAllocator = NaiveBlockAllocator(
|
||||
create_block=NaiveBlock, # type: ignore
|
||||
num_blocks=num_cpu_blocks,
|
||||
block_size=block_size,
|
||||
block_ids=cpu_block_ids,
|
||||
)
|
||||
elif allocator_type == "prefix_caching":
|
||||
gpu_allocator = PrefixCachingBlockAllocator(
|
||||
num_blocks=num_gpu_blocks,
|
||||
block_size=block_size,
|
||||
block_ids=gpu_block_ids,
|
||||
)
|
||||
|
||||
cpu_allocator = PrefixCachingBlockAllocator(
|
||||
num_blocks=num_cpu_blocks,
|
||||
block_size=block_size,
|
||||
block_ids=cpu_block_ids,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown allocator type {allocator_type=}")
|
||||
|
||||
return CpuGpuBlockAllocator(
|
||||
cpu_block_allocator=cpu_allocator,
|
||||
gpu_block_allocator=gpu_allocator,
|
||||
)
|
||||
|
||||
def __init__(self, cpu_block_allocator: BlockAllocator,
|
||||
gpu_block_allocator: BlockAllocator):
|
||||
assert not (
|
||||
cpu_block_allocator.all_block_ids
|
||||
& gpu_block_allocator.all_block_ids
|
||||
), "cpu and gpu block allocators can't have intersection of block ids"
|
||||
|
||||
self._allocators = {
|
||||
Device.CPU: cpu_block_allocator,
|
||||
Device.GPU: gpu_block_allocator,
|
||||
}
|
||||
|
||||
self._swap_mapping: Dict[int, int] = {}
|
||||
self._null_block: Optional[Block] = None
|
||||
|
||||
self._block_ids_to_allocator: Dict[int, BlockAllocator] = {}
|
||||
for _, allocator in self._allocators.items():
|
||||
for block_id in allocator.all_block_ids:
|
||||
self._block_ids_to_allocator[block_id] = allocator
|
||||
|
||||
def allocate_or_get_null_block(self) -> Block:
|
||||
if self._null_block is None:
|
||||
self._null_block = NullBlock(
|
||||
self.allocate_mutable_block(None, Device.GPU))
|
||||
return self._null_block
|
||||
|
||||
def allocate_mutable_block(self,
|
||||
prev_block: Optional[Block],
|
||||
device: Device,
|
||||
extra_hash: Optional[int] = None,
|
||||
seq_id: Optional[int] = None) -> Block:
|
||||
"""Allocates a new mutable block on the specified device.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[Block]): The previous block in the sequence.
|
||||
Used for prefix hashing.
|
||||
device (Device): The device on which to allocate the new block.
|
||||
extra_hash (Optional[int]): The hash value of additional
|
||||
factors, such as adapters, that influence the block hash
|
||||
in the prefix caching block.
|
||||
|
||||
Returns:
|
||||
Block: The newly allocated mutable block.
|
||||
"""
|
||||
return self._allocators[device].allocate_mutable_block(
|
||||
prev_block, extra_hash=extra_hash, seq_id=seq_id)
|
||||
|
||||
def allocate_immutable_blocks(
|
||||
self,
|
||||
prev_block: Optional[Block],
|
||||
block_token_ids: List[List[int]],
|
||||
device: Device,
|
||||
extra_hash: Optional[int] = None,
|
||||
seq_id: Optional[int] = None) -> List[Block]:
|
||||
"""Allocates a new group of immutable blocks with the provided block
|
||||
token IDs on the specified device.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[Block]): The previous block in the sequence.
|
||||
Used for prefix hashing.
|
||||
block_token_ids (List[int]): The list of block token IDs to be
|
||||
stored in the new blocks.
|
||||
device (Device): The device on which to allocate the new block.
|
||||
extra_hash (Optional[int]): The hash value of additional
|
||||
factors, such as adapters, that influence the block hash
|
||||
in the prefix caching block.
|
||||
|
||||
Returns:
|
||||
List[Block]: The newly allocated list of immutable blocks
|
||||
containing the provided block token IDs.
|
||||
"""
|
||||
return self._allocators[device].allocate_immutable_blocks(
|
||||
prev_block, block_token_ids, extra_hash=extra_hash, seq_id=seq_id)
|
||||
|
||||
def allocate_immutable_block(self,
|
||||
prev_block: Optional[Block],
|
||||
token_ids: List[int],
|
||||
device: Device,
|
||||
extra_hash: Optional[int] = None) -> Block:
|
||||
"""Allocates a new immutable block with the provided token IDs on the
|
||||
specified device.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[Block]): The previous block in the sequence.
|
||||
Used for prefix hashing.
|
||||
token_ids (List[int]): The list of token IDs to be stored in the new
|
||||
block.
|
||||
device (Device): The device on which to allocate the new block.
|
||||
extra_hash (Optional[int]): The hash value of additional
|
||||
factors, such as adapters, that influence the block hash
|
||||
in the prefix caching block.
|
||||
|
||||
Returns:
|
||||
Block: The newly allocated immutable block containing the provided
|
||||
token IDs.
|
||||
"""
|
||||
return self._allocators[device].allocate_immutable_block(
|
||||
prev_block, token_ids, extra_hash=extra_hash)
|
||||
|
||||
def free(self, block: Block, seq_id: Optional[int] = None) -> None:
|
||||
"""Frees the memory occupied by the given block.
|
||||
|
||||
Args:
|
||||
block (Block): The block to be freed.
|
||||
"""
|
||||
# Null block should never be freed
|
||||
if isinstance(block, NullBlock):
|
||||
return
|
||||
block_id = block.block_id
|
||||
assert block_id is not None
|
||||
allocator = self._block_ids_to_allocator[block_id]
|
||||
allocator.free(block, seq_id=seq_id)
|
||||
|
||||
def fork(self, last_block: Block) -> List[Block]:
|
||||
"""Creates a new sequence of blocks that shares the same underlying
|
||||
memory as the original sequence.
|
||||
|
||||
Args:
|
||||
last_block (Block): The last block in the original sequence.
|
||||
|
||||
Returns:
|
||||
List[Block]: A new list of blocks that shares the same memory as the
|
||||
original sequence.
|
||||
"""
|
||||
# do not attempt to fork the null block
|
||||
assert not isinstance(last_block, NullBlock)
|
||||
block_id = last_block.block_id
|
||||
assert block_id is not None
|
||||
allocator = self._block_ids_to_allocator[block_id]
|
||||
return allocator.fork(last_block)
|
||||
|
||||
def get_num_free_blocks(self, device: Device, seq_id: Optional[int] = None) -> int:
|
||||
"""Returns the number of free blocks available on the specified device.
|
||||
|
||||
Args:
|
||||
device (Device): The device for which to query the number of free
|
||||
blocks. AssertionError is raised if None is passed.
|
||||
|
||||
Returns:
|
||||
int: The number of free blocks available on the specified device.
|
||||
"""
|
||||
return self._allocators[device].get_num_free_blocks(seq_id=seq_id)
|
||||
|
||||
def get_num_total_blocks(self, device: Device, seq_id: Optional[int] = None) -> int:
|
||||
return self._allocators[device].get_num_total_blocks(seq_id=seq_id)
|
||||
|
||||
def get_physical_block_id(self, device: Device, absolute_id: int) -> int:
|
||||
"""Returns the zero-offset block id on certain device given the
|
||||
absolute block id.
|
||||
|
||||
Args:
|
||||
device (Device): The device for which to query relative block id.
|
||||
absolute_id (int): The absolute block id for the block in
|
||||
whole allocator.
|
||||
|
||||
Returns:
|
||||
int: The zero-offset block id on certain device.
|
||||
"""
|
||||
return self._allocators[device].get_physical_block_id(absolute_id)
|
||||
|
||||
def swap(self, blocks: List[Block], src_device: Device,
|
||||
dst_device: Device) -> Dict[int, int]:
|
||||
"""Execute the swap for the given blocks from source_device
|
||||
on to dest_device, save the current swap mapping and append
|
||||
them to the accumulated `self._swap_mapping` for each
|
||||
scheduling move.
|
||||
|
||||
Args:
|
||||
blocks: List of blocks to be swapped.
|
||||
src_device (Device): Device to swap the 'blocks' from.
|
||||
dst_device (Device): Device to swap the 'blocks' to.
|
||||
|
||||
Returns:
|
||||
Dict[int, int]: Swap mapping from source_device
|
||||
on to dest_device.
|
||||
"""
|
||||
src_block_ids = [block.block_id for block in blocks]
|
||||
self._allocators[src_device].swap_out(blocks)
|
||||
self._allocators[dst_device].swap_in(blocks)
|
||||
dst_block_ids = [block.block_id for block in blocks]
|
||||
|
||||
current_swap_mapping: Dict[int, int] = {}
|
||||
for src_block_id, dst_block_id in zip(src_block_ids, dst_block_ids):
|
||||
if src_block_id is not None and dst_block_id is not None:
|
||||
self._swap_mapping[src_block_id] = dst_block_id
|
||||
current_swap_mapping[src_block_id] = dst_block_id
|
||||
return current_swap_mapping
|
||||
|
||||
def get_num_full_blocks_touched(self, blocks: List[Block],
|
||||
device: Device) -> int:
|
||||
"""Returns the number of full blocks that will be touched by
|
||||
swapping in/out the given blocks on to the 'device'.
|
||||
|
||||
Args:
|
||||
blocks: List of blocks to be swapped.
|
||||
device (Device): Device to swap the 'blocks' on.
|
||||
|
||||
Returns:
|
||||
int: the number of full blocks that will be touched by
|
||||
swapping in/out the given blocks on to the 'device'.
|
||||
Non full blocks are ignored when deciding the number
|
||||
of blocks to touch.
|
||||
"""
|
||||
return self._allocators[device].get_num_full_blocks_touched(blocks)
|
||||
|
||||
def clear_copy_on_writes(self, seq_id: Optional[int] = None) -> List[Tuple[int, int]]:
|
||||
"""Clears the copy-on-write (CoW) state and returns the mapping of
|
||||
source to destination block IDs.
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, int]]: A list mapping source block IDs to
|
||||
destination block IDs.
|
||||
"""
|
||||
# CoW only supported on GPU
|
||||
device = Device.GPU
|
||||
return self._allocators[device].clear_copy_on_writes()
|
||||
|
||||
def mark_blocks_as_accessed(self, block_ids: List[int],
|
||||
now: float) -> None:
|
||||
"""Mark blocks as accessed, only use for prefix caching."""
|
||||
# Prefix caching only supported on GPU.
|
||||
device = Device.GPU
|
||||
return self._allocators[device].mark_blocks_as_accessed(block_ids, now)
|
||||
|
||||
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
|
||||
"""Mark blocks as accessed, only use for prefix caching."""
|
||||
# Prefix caching only supported on GPU.
|
||||
device = Device.GPU
|
||||
return self._allocators[device].mark_blocks_as_computed(block_ids)
|
||||
|
||||
def get_common_computed_block_ids(
|
||||
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
|
||||
# Prefix caching only supported on GPU.
|
||||
device = Device.GPU
|
||||
return self._allocators[device].get_common_computed_block_ids(
|
||||
computed_seq_block_ids)
|
||||
|
||||
@property
|
||||
def all_block_ids(self) -> FrozenSet[int]:
|
||||
return frozenset(self._block_ids_to_allocator.keys())
|
||||
|
||||
def get_prefix_cache_hit_rate(self, device: Device) -> float:
|
||||
"""Prefix cache hit rate. -1 means not supported or disabled."""
|
||||
assert device in self._allocators
|
||||
return self._allocators[device].get_prefix_cache_hit_rate()
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
"""Reset prefix cache for all devices."""
|
||||
success = True
|
||||
for allocator in self._allocators.values():
|
||||
success = success and allocator.reset_prefix_cache()
|
||||
return success
|
||||
|
||||
def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
|
||||
"""Returns and clears the mapping of source to destination block IDs.
|
||||
Will be called after every swapping operations for now, and after every
|
||||
schedule when BlockManagerV2 becomes the default. Currently not useful.
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, int]]: A mapping of source to destination block IDs.
|
||||
"""
|
||||
mapping = self._swap_mapping.copy()
|
||||
self._swap_mapping.clear()
|
||||
return list(mapping.items())
|
||||
|
||||
def find_cached_blocks_prefix(
|
||||
self,
|
||||
block_hashes: List[int],
|
||||
device: Device = Device.GPU,
|
||||
) -> List[int]:
|
||||
return self._allocators[device].find_cached_blocks_prefix(block_hashes)
|
||||
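# --- Illustrative sketch (not part of the committed file) ---
# CpuGpuBlockAllocator.create() hands out one contiguous id space, GPU ids
# first, then CPU ids (with one id reserved for padding on HPU). The split
# below mirrors that bookkeeping with made-up counts:
num_gpu_blocks, num_cpu_blocks, reserved_blocks = 6, 4, 0
block_ids = list(range(reserved_blocks, num_gpu_blocks + num_cpu_blocks))
num_gpu_blocks -= reserved_blocks
gpu_block_ids = block_ids[:num_gpu_blocks]
cpu_block_ids = block_ids[num_gpu_blocks:]
assert gpu_block_ids == [0, 1, 2, 3, 4, 5]
assert cpu_block_ids == [6, 7, 8, 9]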
464
vllm_vacc/vllm/core/block/naive_block.py
Normal file
@@ -0,0 +1,464 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from collections import deque
|
||||
from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
|
||||
get_all_blocks_recursively)
|
||||
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
|
||||
|
||||
from typing import Dict
|
||||
from vllm.logger import init_logger
|
||||
import os
|
||||
|
||||
max_seq_num = int(os.getenv("MAX_SEQ_NUM", 4))
|
||||
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
||||
|
||||
Refcount = int
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
class NaiveBlockAllocator(BlockAllocator):
|
||||
"""A simple block allocator that manages blocks of memory without prefix
|
||||
caching.
|
||||
|
||||
Args:
|
||||
create_block (Block.Factory): A factory function for creating new
|
||||
blocks. This is used when a NaiveBlockAllocator is composed within
|
||||
a prefix caching allocator -- the naive block allocator must
|
||||
construct prefix caching blocks (but shouldn't know anything else
|
||||
about them).
|
||||
num_blocks (int): The total number of blocks to manage.
|
||||
block_size (int): The size of each block in tokens.
|
||||
block_ids (Optional[Iterable[int]], optional): An optional iterable of
|
||||
block IDs. If not provided, block IDs will be assigned sequentially
|
||||
from 0 to num_blocks - 1.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
create_block: Block.Factory,
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
block_ids: Optional[Iterable[int]] = None,
|
||||
block_pool: Optional[BlockPool] = None,
|
||||
):
|
||||
# new mapping seqid : block_group_id
|
||||
self.is_partitioned = False
|
||||
self.num_blocks = num_blocks
|
||||
self.seq_mapping: Dict[int, List[int]] = {}
|
||||
|
||||
if block_ids is None:
|
||||
block_ids = range(num_blocks)
|
||||
|
||||
self._free_block_indices_all: Deque[BlockId] = deque(block_ids)
|
||||
self._all_block_indices = frozenset(block_ids)
|
||||
assert len(self._all_block_indices) == num_blocks
|
||||
|
||||
self._refcounter = RefCounter(
|
||||
all_block_indices=self._free_block_indices_all)
|
||||
self._block_size = block_size
|
||||
|
||||
self._cow_tracker = CopyOnWriteTracker(
|
||||
refcounter=self._refcounter.as_readonly())
|
||||
|
||||
if block_pool is None:
|
||||
extra_factor = 4
|
||||
# Pre-allocate "num_blocks * extra_factor" block objects.
|
||||
# The "* extra_factor" is a buffer to allow more block objects
|
||||
# than physical blocks
|
||||
self._block_pool = BlockPool(self._block_size, create_block, self,
|
||||
num_blocks * extra_factor)
|
||||
else:
|
||||
# In this case, the block pool is provided by the caller,
|
||||
# which means that there is most likely a need to share
|
||||
# a block pool between allocators
|
||||
self._block_pool = block_pool
|
||||
|
||||
# partition blocks to block groups
|
||||
self.partition_blocks()
|
||||
|
||||
def allocate_immutable_block(self,
|
||||
prev_block: Optional[Block],
|
||||
token_ids: List[int],
|
||||
extra_hash: Optional[int] = None,
|
||||
device: Optional[Device] = None) -> Block:
|
||||
"""Allocates a new immutable block with the given token IDs, linked to
|
||||
the previous block.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[Block]): The previous block in the sequence. If
|
||||
None, then the block to be allocated is the first block in the
|
||||
sequence.
|
||||
token_ids (List[int]): The token IDs to be stored in the new block.
|
||||
|
||||
Returns:
|
||||
Block: The newly allocated immutable block.
|
||||
"""
|
||||
assert device is None
|
||||
block = self.allocate_mutable_block(prev_block=prev_block)
|
||||
block.append_token_ids(token_ids)
|
||||
return block
|
||||
|
||||
def allocate_immutable_blocks(
|
||||
self,
|
||||
prev_block: Optional[Block],
|
||||
block_token_ids: List[List[int]],
|
||||
extra_hash: Optional[int] = None,
|
||||
device: Optional[Device] = None,
|
||||
seq_id: Optional[int] = None) -> List[Block]:
|
||||
assert device is None
|
||||
num_blocks = len(block_token_ids)
|
||||
|
||||
block_ids = []
|
||||
for i in range(num_blocks):
|
||||
block_ids.append(self._allocate_block_id(seq_id=seq_id))
|
||||
|
||||
blocks = []
|
||||
for i in range(num_blocks):
|
||||
prev_block = self._block_pool.init_block(
|
||||
prev_block=prev_block,
|
||||
token_ids=block_token_ids[i],
|
||||
block_size=self._block_size,
|
||||
physical_block_id=block_ids[i])
|
||||
blocks.append(prev_block)
|
||||
|
||||
return blocks
|
||||
|
||||
def allocate_mutable_block(self,
|
||||
prev_block: Optional[Block],
|
||||
extra_hash: Optional[int] = None,
|
||||
device: Optional[Device] = None,
|
||||
seq_id: Optional[int] = None) -> Block:
|
||||
"""Allocates a new mutable block, linked to the previous block.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[Block]): The previous block in the sequence. If
|
||||
None, then the block to be allocated is the first block in the
|
||||
sequence.
|
||||
|
||||
Returns:
|
||||
Block: The newly allocated mutable block.
|
||||
"""
|
||||
assert device is None
|
||||
assert seq_id is not None
|
||||
block_id = self._allocate_block_id(seq_id)
|
||||
block = self._block_pool.init_block(prev_block=prev_block,
|
||||
token_ids=[],
|
||||
block_size=self._block_size,
|
||||
physical_block_id=block_id)
|
||||
return block
|
||||
|
||||
def _allocate_block_id(self, seq_id: Optional[int] = None) -> BlockId:
|
||||
assert seq_id is not None
|
||||
# always use the latest grp id to allocate
|
||||
# since the previous grps are exhausted
|
||||
grp_id_list = self.get_group_list(seq_id)
|
||||
if len(grp_id_list) == 0 or len(self._free_block_indices[grp_id_list[-1]]) == 0:
|
||||
if not self._free_block_grp_indices:
|
||||
# no more block id in the block group pool; should not reach here
raise BlockAllocator.NoFreeBlocksError()
|
||||
else:
|
||||
# pop a new block and add to seq_mapping
|
||||
grp_id = self._free_block_grp_indices.popleft()
|
||||
grp_id_list.append(grp_id)
|
||||
self.seq_mapping[seq_id] = grp_id_list
|
||||
grp_id = grp_id_list[-1]
|
||||
block_id = self._free_block_indices[grp_id].popleft()
|
||||
self._refcounter.incr(block_id)
|
||||
return block_id
|
||||
|
||||
def _free_block_id(self, block: Union[Block, BlockId], seq_id: Optional[int] = None) -> None:
|
||||
assert seq_id is not None
|
||||
grp_id_list = self.get_group_list(seq_id)
|
||||
if isinstance(block, Block):
|
||||
block_id = block.block_id
|
||||
block.block_id = None
|
||||
else:
|
||||
block_id = block
|
||||
assert block_id is not None
|
||||
|
||||
# block_id should always be in grp_id_list[0]
|
||||
# since the block id is freed in block id order
|
||||
grp_id = grp_id_list[-1]
|
||||
assert block_id in self._block_grp_indices[grp_id], f"grp_id: {grp_id} block_id:{block_id}"
|
||||
|
||||
refcount = self._refcounter.decr(block_id)
|
||||
if refcount == 0:
|
||||
self._free_block_indices[grp_id].appendleft(block_id)
|
||||
if len(self._free_block_indices[grp_id]) == len(self._block_grp_indices[grp_id]):
|
||||
# free group
|
||||
self.seq_mapping[seq_id].remove(grp_id)
|
||||
if len(self.seq_mapping[seq_id]) == 0:
|
||||
# free seq_id
|
||||
del self.seq_mapping[seq_id]
|
||||
# collect back to block group pool
|
||||
self._free_block_grp_indices.appendleft(grp_id)
|
||||
|
||||
def free(self, block: Block, keep_block_object: bool = False, seq_id: Optional[int] = None) -> None:
|
||||
# Release the physical block id
|
||||
self._free_block_id(block, seq_id=seq_id)
|
||||
|
||||
# Release the block object
|
||||
if not keep_block_object:
|
||||
self._block_pool.free_block(block)
|
||||
|
||||
def free_block_id(self, block_id: BlockId, seq_id: Optional[int] = None) -> None:
|
||||
assert seq_id is not None
|
||||
self._free_block_id(block_id, seq_id)
|
||||
|
||||
def fork(self, last_block: Block, seq_id: Optional[int] = None) -> List[Block]:
|
||||
"""Creates a new sequence of blocks that shares the same underlying
|
||||
memory as the original sequence.
|
||||
|
||||
Args:
|
||||
last_block (Block): The last block in the original sequence.
|
||||
|
||||
Returns:
|
||||
List[Block]: The new sequence of blocks that shares the same memory
|
||||
as the original sequence.
|
||||
"""
|
||||
source_blocks = get_all_blocks_recursively(last_block)
|
||||
|
||||
forked_blocks: List[Block] = []
|
||||
prev_block = None
|
||||
grp_id_list = self.get_group_list(seq_id)
|
||||
for block in source_blocks:
|
||||
# Increment refcount for each block.
|
||||
assert block.block_id is not None
|
||||
grp_id = self.get_group_id(block.block_id, grp_id_list)
|
||||
assert grp_id != -1, "can't locate block group"
|
||||
refcount = self._refcounter.incr(block.block_id)
|
||||
assert refcount != 1, "can't fork free'd block"
|
||||
|
||||
forked_block = self._block_pool.init_block(
|
||||
prev_block=prev_block,
|
||||
token_ids=block.token_ids,
|
||||
block_size=self._block_size,
|
||||
physical_block_id=block.block_id)
|
||||
|
||||
forked_blocks.append(forked_block)
|
||||
prev_block = forked_blocks[-1]
|
||||
|
||||
return forked_blocks
|
||||
|
||||
def partition_blocks(self) -> None:
|
||||
# only partition once in each vllm server lifecycle
|
||||
if self.is_partitioned: #and len(self.seq_mapping) > 0:
|
||||
return
|
||||
|
||||
self.is_partitioned = True
|
||||
self._blk_num_per_grp = env_blk_grp_size // self._block_size
|
||||
self._all_blk_grp_num = self.num_blocks // self._blk_num_per_grp
|
||||
|
||||
block_groups = []
|
||||
for i in range(self._all_blk_grp_num):
|
||||
start = i * self._blk_num_per_grp
|
||||
block_groups.append([k for k in range(start, start + self._blk_num_per_grp)])
|
||||
|
||||
self._free_block_grp_indices: Deque[BlockId] = deque(range(self._all_blk_grp_num))
|
||||
# self._free_block_grp_indices: Deque[BlockId] = deque(range(self._all_blk_grp_num-1,-1,-1))
|
||||
self._free_block_indices: List[Deque[BlockId]] = [deque(block_ids) for block_ids in block_groups]
|
||||
self._block_grp_indices: List[FrozenSet] = [frozenset(block_ids) for block_ids in block_groups]
|
||||
# self._all_block_indices =[frozenset(block_ids) for block_ids in block_groups]
|
||||
|
||||
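# --- Illustrative sketch (not part of the committed file) ---
# partition_blocks() groups the physical blocks into fixed-size "block groups"
# of BLOCK_GROUP_SIZE tokens each; a sequence then draws whole groups one at a
# time. Same arithmetic with made-up numbers (BLOCK_GROUP_SIZE assumed 8192):
num_blocks, block_size, block_group_size = 64, 1024, 8192
blk_num_per_grp = block_group_size // block_size          # 8 blocks per group
all_blk_grp_num = num_blocks // blk_num_per_grp           # 8 groups in total
block_groups = [list(range(i * blk_num_per_grp, (i + 1) * blk_num_per_grp))
                for i in range(all_blk_grp_num)]
assert block_groups[0] == [0, 1, 2, 3, 4, 5, 6, 7]
assert len(block_groups) == all_blk_grp_num == 8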
# get group id list according to block_id
|
||||
def get_group_id(self, block_id, grp_id_list) -> int:
|
||||
for i in grp_id_list:
|
||||
if block_id in self._block_grp_indices[i]:
|
||||
return i
|
||||
assert False
|
||||
# return -1
|
||||
|
||||
# get group id list according to seq_id
|
||||
def get_group_list(self, seq_id) -> List[int]:
|
||||
"""Get group id list acoording to current seq_id
|
||||
key: seq_id, value: [grp_id, grp_id, ...]
|
||||
"""
|
||||
assert seq_id is not None
|
||||
grp_id_list = []
|
||||
if seq_id in self.seq_mapping:
|
||||
grp_id_list = self.seq_mapping[seq_id]
|
||||
return grp_id_list
|
||||
|
||||
def get_num_free_blocks(self, seq_id: Optional[int] = None) -> int:
|
||||
free_blocks = len(self._free_block_grp_indices) * self._blk_num_per_grp
|
||||
if seq_id is not None:
|
||||
if seq_id in self.seq_mapping:
|
||||
# seq_id is already allocated
|
||||
grp_id_list = self.seq_mapping[seq_id]
|
||||
free_blocks += len(self._free_block_indices[grp_id_list[-1]])
|
||||
else:
|
||||
# new seq_id
|
||||
if len(self.seq_mapping) >= max_seq_num:
|
||||
return 0
|
||||
# means real allocate, only consider free 8K block groups
|
||||
return free_blocks
|
||||
else:
|
||||
# for memory usage analysis / swap
|
||||
# need consider also the free blocks that in 8K blocks groups
|
||||
for _, grp_id_list in self.seq_mapping.items():
|
||||
if len(grp_id_list) > 0:
|
||||
free_blocks += len(self._free_block_indices[grp_id_list[-1]])
|
||||
return free_blocks
|
||||
|
||||
def get_num_total_blocks(self, seq_id: Optional[int] = None) -> int:
|
||||
return len(self._all_block_indices)
|
||||
|
||||
def get_physical_block_id(self, absolute_id: int) -> int:
|
||||
"""Returns the zero-offset block id on certain block allocator
|
||||
given the absolute block id.
|
||||
|
||||
Args:
|
||||
absolute_id (int): The absolute block id for the block
|
||||
in whole allocator.
|
||||
|
||||
Returns:
|
||||
int: The zero-offset block id on certain device.
|
||||
"""
|
||||
return sorted(self._all_block_indices).index(absolute_id)
|
||||
|
||||
@property
|
||||
def refcounter(self):
|
||||
return self._refcounter
|
||||
|
||||
@property
|
||||
def all_block_ids(self) -> FrozenSet[int]:
|
||||
return self._all_block_indices
|
||||
|
||||
def cow_block_if_not_appendable(self, block: Block, seq_id: Optional[int] = None) -> BlockId:
|
||||
"""Performs a copy-on-write operation on the given block if it is not
|
||||
appendable.
|
||||
|
||||
Args:
|
||||
block (Block): The block to check for copy-on-write.
|
||||
|
||||
Returns:
|
||||
BlockId: The block index of the new block if a copy-on-write
|
||||
operation was performed, or the original block index if
|
||||
no copy-on-write was necessary.
|
||||
"""
|
||||
src_block_id = block.block_id
|
||||
assert src_block_id is not None
|
||||
|
||||
if self._cow_tracker.is_appendable(block):
|
||||
return src_block_id
|
||||
|
||||
self._free_block_id(block, seq_id)
|
||||
trg_block_id = self._allocate_block_id()
|
||||
|
||||
self._cow_tracker.record_cow(src_block_id, trg_block_id)
|
||||
|
||||
return trg_block_id
|
||||
|
||||
def clear_copy_on_writes(self, seq_id: Optional[int] = None) -> List[Tuple[BlockId, BlockId]]:
|
||||
"""Returns the copy-on-write source->destination mapping and clears it.
|
||||
|
||||
Returns:
|
||||
List[Tuple[BlockId, BlockId]]: A list mapping source
|
||||
block indices to destination block indices.
|
||||
"""
|
||||
return self._cow_tracker.clear_cows()
|
||||
|
||||
def mark_blocks_as_accessed(self, block_ids: List[int],
|
||||
now: float) -> None:
|
||||
"""Mark blocks as accessed, used in prefix caching.
|
||||
|
||||
Since the naive allocator does not implement prefix caching, we do
|
||||
nothing.
|
||||
"""
|
||||
pass
|
||||
|
||||
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
|
||||
"""Mark blocks as computed, used in prefix caching.
|
||||
|
||||
Since the naive allocator does not implement prefix caching, we do
|
||||
nothing.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_common_computed_block_ids(
|
||||
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
|
||||
"""Determine blocks that can be skipped in prefill.
|
||||
|
||||
Since the naive allocator does not support prefix caching, always return
|
||||
an empty list.
|
||||
"""
|
||||
return []
|
||||
|
||||
def promote_to_immutable_block(self, block: Block) -> BlockId:
|
||||
raise NotImplementedError("There is no promotion for naive blocks")
|
||||
|
||||
def get_num_full_blocks_touched(self, blocks: List[Block]) -> int:
|
||||
"""Returns the number of full blocks that will be touched by
|
||||
swapping in/out.
|
||||
|
||||
Args:
|
||||
blocks: List of blocks to be swapped.
|
||||
Returns:
|
||||
int: the number of full blocks that will be touched by
|
||||
swapping in/out the given blocks. Non full blocks are ignored
|
||||
when deciding the number of blocks to touch.
|
||||
"""
|
||||
# NOTE: for naive block, we use set to eliminate common blocks among
|
||||
# seqs, also we compare the empty slots in the mutable blocks with
|
||||
# lookahead slots to get the number of unique new block that are
|
||||
# needed.
|
||||
old_block_set = set()
|
||||
for block in blocks:
|
||||
if block.is_full:
|
||||
old_block_set.add(block)
|
||||
return len(old_block_set)
|
||||
|
||||
def swap_out(self, blocks: List[Block], seq_id: Optional[int] = None) -> None:
|
||||
for block in blocks:
|
||||
self._free_block_id(block, seq_id)
|
||||
|
||||
def swap_in(self, blocks: List[Block]) -> None:
|
||||
for block in blocks:
|
||||
# Here we allocate either immutable or mutable block and then
|
||||
# extract its block_id. Note that the block object is released
|
||||
# and the block_id is assigned to "block" to allow reusing the
|
||||
# existing "block" object
|
||||
if block.is_full:
|
||||
tmp_block = self.allocate_immutable_block(
|
||||
prev_block=block.prev_block, token_ids=block.token_ids)
|
||||
else:
|
||||
tmp_block = self.allocate_mutable_block(
|
||||
prev_block=block.prev_block)
|
||||
tmp_block.append_token_ids(block.token_ids)
|
||||
|
||||
block_id = tmp_block.block_id
|
||||
tmp_block.block_id = None
|
||||
self._block_pool.free_block(tmp_block)
|
||||
|
||||
block.block_id = block_id # Assign block_id
|
||||
|
||||
def get_prefix_cache_hit_rate(self) -> float:
|
||||
return -1
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
"""No prefix cache for naive block allocator."""
|
||||
return True
|
||||
|
||||
def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
|
||||
# Not applicable for naive block allocator.
|
||||
return []
|
||||
|
||||
class NaiveBlock(Block):
|
||||
def append_token_ids(self, token_ids: List[int], seq_id: Optional[int] = None) -> None:
|
||||
"""Appends the given token IDs to the block and performs a
|
||||
copy-on-write if necessary.
|
||||
|
||||
Args:
|
||||
token_ids (Optional[List[int]]): The token IDs to be appended
|
||||
to the block.
|
||||
"""
|
||||
assert seq_id is not None
|
||||
self._append_token_ids_no_cow(token_ids)
|
||||
|
||||
if self._block_id is not None:
|
||||
self._block_id = (self._allocator.cow_block_if_not_appendable(
|
||||
self._cow_target, seq_id))
|
||||
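# Hedged usage sketch (illustration only; `allocator` and `seq_id` are
# assumptions, not names defined in this module):
#
#   block = allocator.allocate_mutable_block(prev_block=None, seq_id=seq_id)
#   block.append_token_ids([1, 2, 3], seq_id=seq_id)
#   # If the underlying block id is shared, cow_block_if_not_appendable()
#   # rebinds the block to a freshly allocated id and records the
#   # src -> dst pair, which clear_copy_on_writes() later returns.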
942
vllm_vacc/vllm/core/block/prefix_caching_block.py
Normal file
@@ -0,0 +1,942 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Token blocks."""
|
||||
import sys
|
||||
from bisect import bisect_left
|
||||
from os.path import commonprefix
|
||||
from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set,
|
||||
Tuple)
|
||||
|
||||
from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
|
||||
get_all_blocks_recursively)
|
||||
from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device,
|
||||
DeviceAwareBlockAllocator)
|
||||
from vllm.core.block.naive_block import (BlockPool, NaiveBlock)
|
||||
from vllm_vacc.vllm.core.block.naive_block import NaiveBlockAllocator
|
||||
from vllm.core.block.prefix_caching_block import BlockTracker, assert_prefix_caching_block_or_none
|
||||
|
||||
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import Sequence
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
PrefixHash = int
|
||||
|
||||
# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME
|
||||
# so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME,
|
||||
# then we know this block hasn't been accessed yet.
|
||||
_DEFAULT_LAST_ACCESSED_TIME = -1
|
||||
|
||||
|
||||
|
||||
|
||||
class PrefixCachingBlockAllocator(BlockAllocator):
|
||||
"""A block allocator that implements prefix caching.
|
||||
|
||||
The PrefixCachingBlockAllocator maintains a cache of blocks based on their
|
||||
content hash. It reuses blocks with the same content hash to avoid redundant
|
||||
memory allocation. The allocator also supports copy-on-write operations.
|
||||
|
||||
Args:
|
||||
num_blocks (int): The total number of blocks to manage.
|
||||
block_size (int): The size of each block in tokens.
|
||||
block_ids(Optional[Iterable[int]], optional): An optional iterable of
|
||||
block IDs. If not provided, block IDs will be assigned sequentially
|
||||
from 0 to num_blocks - 1.
|
||||
"""
|
||||
|
||||
# Note that we use 'None' as a string here instead of None because
|
||||
# as of Python 3.12, hash(None) returns a constant predictable value.
|
||||
# This could possibly make it easier to find and exploit hash
|
||||
# collisions. 'None' as a string will be hashed differently per process,
|
||||
# but consistently within the same process. This is the same as the
|
||||
# behavior of None prior to Python 3.12.
|
||||
_none_hash: int = hash('None')
|
||||
|
||||
# Implements Block.Factory.
|
||||
def __init__(
|
||||
self,
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
block_ids: Optional[Iterable[int]] = None,
|
||||
eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
|
||||
):
|
||||
if block_ids is None:
|
||||
block_ids = range(num_blocks)
|
||||
|
||||
self._block_size = block_size
|
||||
|
||||
# A mapping of prefix hash to block index. All blocks which have a
|
||||
# prefix hash will be in this dict, even if they have refcount 0.
|
||||
self._cached_blocks: Dict[PrefixHash, BlockId] = {}
|
||||
|
||||
# A list of immutable block IDs that have been touched by scheduler
|
||||
# and should be marked as computed after an entire batch of sequences
|
||||
# are scheduled.
|
||||
self._touched_blocks: Set[BlockId] = set()
|
||||
|
||||
# Used to track status of each physical block id
|
||||
self._block_tracker: Dict[BlockId, BlockTracker] = {}
|
||||
for block_id in block_ids:
|
||||
self._block_tracker[block_id] = BlockTracker()
|
||||
|
||||
# Pre-allocate "num_blocks * extra_factor" block objects.
|
||||
# The "* extra_factor" is a buffer to allow more block objects
|
||||
# than physical blocks
|
||||
extra_factor = 4
|
||||
self._block_pool = BlockPool(self._block_size, self._create_block,
|
||||
self, num_blocks * extra_factor)
|
||||
|
||||
# An allocator for blocks that do not have prefix hashes.
|
||||
self._hashless_allocator = NaiveBlockAllocator(
|
||||
create_block=self._create_block, # type: ignore
|
||||
num_blocks=num_blocks,
|
||||
block_size=block_size,
|
||||
block_ids=block_ids,
|
||||
block_pool=self._block_pool, # Share block pool here
|
||||
)
|
||||
|
||||
# Evictor used to maintain how we want to handle those computed blocks
|
||||
# if we find memory pressure is high.
|
||||
self.eviction_policy = eviction_policy
|
||||
self.evictor: Evictor = make_evictor(self.eviction_policy)
|
||||
|
||||
# We share the refcounter between allocators. This allows us to promote
|
||||
# blocks originally allocated in the hashless allocator to immutable
|
||||
# blocks.
|
||||
self._refcounter = self._hashless_allocator.refcounter
|
||||
|
||||
self._cow_tracker = CopyOnWriteTracker(
|
||||
refcounter=self._refcounter.as_readonly())
|
||||
|
||||
self.metric_data = CacheMetricData()
|
||||
|
||||
def _create_block(
|
||||
self,
|
||||
prev_block: Optional[Block],
|
||||
token_ids: List[int],
|
||||
block_size: int,
|
||||
allocator: BlockAllocator,
|
||||
block_id: Optional[int] = None,
|
||||
computed: bool = False,
|
||||
extra_hash: Optional[int] = None,
|
||||
) -> Block:
|
||||
# Bind block to self.
|
||||
allocator = self
|
||||
|
||||
return PrefixCachingBlock(
|
||||
prev_block=prev_block,
|
||||
token_ids=token_ids,
|
||||
block_size=block_size,
|
||||
block_id=block_id,
|
||||
allocator=allocator,
|
||||
computed=computed,
|
||||
extra_hash=extra_hash,
|
||||
)
|
||||
|
||||
def allocate_immutable_block(self,
|
||||
prev_block: Optional[Block],
|
||||
token_ids: List[int],
|
||||
extra_hash: Optional[int] = None,
|
||||
device: Optional[Device] = None,
|
||||
seq_id: Optional[int] = None) -> Block:
|
||||
"""Allocates an immutable block with the given token IDs, reusing cached
|
||||
blocks if possible.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[Block]): The previous block in the sequence.
|
||||
token_ids (List[int]): The token IDs to be stored in the block.
|
||||
|
||||
Returns:
|
||||
Block: The allocated immutable block.
|
||||
"""
|
||||
assert device is None
|
||||
assert_prefix_caching_block_or_none(prev_block)
|
||||
|
||||
# First, try to create a block that points to cached data
|
||||
block = self._block_pool.init_block(prev_block=prev_block,
|
||||
token_ids=token_ids,
|
||||
block_size=self._block_size,
|
||||
physical_block_id=None,
|
||||
extra_hash=extra_hash)
|
||||
assert block.content_hash is not None
|
||||
|
||||
cached_block_id = self._cached_blocks.get(block.content_hash, None)
|
||||
if cached_block_id is not None:
|
||||
self.metric_data.query(hit=True)
|
||||
block.block_id = cached_block_id
|
||||
self._incr_refcount_cached_block(block)
|
||||
return block
|
||||
self.metric_data.query(hit=False)
|
||||
self._block_pool.free_block(block)
|
||||
|
||||
# No cached block => Allocate a new block
|
||||
block = self.allocate_mutable_block(prev_block, extra_hash=extra_hash, seq_id=seq_id)
|
||||
logger.warning(f"Teng seq_id: {seq_id} block: {block.block_id} hash: {block.content_hash}")
|
||||
|
||||
block.append_token_ids(token_ids, seq_id=seq_id)
|
||||
return block
|
||||
|
||||
def allocate_immutable_blocks(
|
||||
self,
|
||||
prev_block: Optional[Block],
|
||||
block_token_ids: List[List[int]],
|
||||
extra_hash: Optional[int] = None,
|
||||
device: Optional[Device] = None,
|
||||
seq_id: Optional[int] = None) -> List[Block]:
|
||||
blocks = []
|
||||
for token_ids in block_token_ids:
|
||||
prev_block = self.allocate_immutable_block(prev_block=prev_block,
|
||||
token_ids=token_ids,
|
||||
device=device,
|
||||
extra_hash=extra_hash,
|
||||
seq_id=seq_id)
|
||||
blocks.append(prev_block)
|
||||
return blocks
|
||||
|
||||
def allocate_mutable_block(self,
|
||||
prev_block: Optional[Block],
|
||||
extra_hash: Optional[int] = None,
|
||||
device: Optional[Device] = None,
|
||||
seq_id: Optional[int] = None) -> Block:
|
||||
"""Allocates a mutable block. If there are no free blocks, this will
|
||||
evict unused cached blocks.
|
||||
|
||||
Args:
|
||||
prev_block (Block): The previous block in the sequence.
|
||||
Unlike in the superclass, None is not allowed here.
|
||||
|
||||
Returns:
|
||||
Block: The allocated mutable block.
|
||||
"""
|
||||
assert device is None
|
||||
assert seq_id is not None
|
||||
assert_prefix_caching_block_or_none(prev_block)
|
||||
|
||||
block_id = self._allocate_block_id(seq_id)
|
||||
block = self._block_pool.init_block(prev_block=prev_block,
|
||||
token_ids=[],
|
||||
block_size=self._block_size,
|
||||
physical_block_id=block_id,
|
||||
extra_hash=extra_hash)
|
||||
assert not block.computed
|
||||
assert block.content_hash is None
|
||||
return block
|
||||
|
||||
def _incr_refcount_cached_block(self, block: Block) -> None:
|
||||
# Set this block to be "computed" since it is pointing to a
|
||||
# cached block id (which was already computed)
|
||||
block.computed = True
|
||||
|
||||
block_id = block.block_id
|
||||
assert block_id is not None
|
||||
|
||||
refcount = self._refcounter.incr(block_id)
|
||||
if refcount == 1:
|
||||
# In case a cached block was evicted, restore its tracking
|
||||
if block_id in self.evictor:
|
||||
self.evictor.remove(block_id)
|
||||
|
||||
self._track_block_id(block_id, computed=True)
|
||||
|
||||
def _decr_refcount_cached_block(self, block: Block) -> None:
|
||||
# Ensure this is immutable/cached block
|
||||
assert block.content_hash is not None
|
||||
|
||||
block_id = block.block_id
|
||||
assert block_id is not None
|
||||
|
||||
refcount = self._refcounter.decr(block_id)
|
||||
if refcount > 0:
|
||||
block.block_id = None
|
||||
return
|
||||
else:
|
||||
assert refcount == 0
|
||||
|
||||
# No longer used
|
||||
assert block.content_hash in self._cached_blocks
|
||||
|
||||
# Add the cached block to the evictor
|
||||
# (This keeps the cached block around so it can be reused)
|
||||
self.evictor.add(block_id, block.content_hash, block.num_tokens_total,
|
||||
self._block_tracker[block_id].last_accessed)
|
||||
|
||||
# Stop tracking the block
|
||||
self._untrack_block_id(block_id)
|
||||
|
||||
block.block_id = None
|
||||
|
||||
def _decr_refcount_hashless_block(self, block: Block, seq_id: Optional[int] = None) -> None:
|
||||
block_id = block.block_id
|
||||
assert block_id is not None
|
||||
|
||||
# We may have a fork case where block is shared,
|
||||
# in which case, we cannot remove it from tracking
|
||||
refcount = self._refcounter.get(block_id)
|
||||
if refcount == 1:
|
||||
self._untrack_block_id(block_id)
|
||||
|
||||
# Decrement refcount of the block_id, but do not free the block object
|
||||
# itself (will be handled by the caller)
|
||||
self._hashless_allocator.free(block, keep_block_object=True, seq_id=seq_id)
|
||||
|
||||
def _allocate_block_id(self, seq_id: Optional[int] = None) -> BlockId:
|
||||
"""First tries to allocate a block id from the hashless allocator,
|
||||
and if there are no blocks, then tries to evict an unused cached block.
|
||||
"""
|
||||
assert seq_id is not None
|
||||
hashless_block_id = self._maybe_allocate_hashless_block_id(seq_id=seq_id)
|
||||
if hashless_block_id is not None:
|
||||
return hashless_block_id
|
||||
|
||||
evicted_block_id = self._maybe_allocate_evicted_block_id()
|
||||
if evicted_block_id is not None:
|
||||
return evicted_block_id
|
||||
|
||||
# No block available in hashless allocator, nor in unused cache blocks.
|
||||
raise BlockAllocator.NoFreeBlocksError()
|
||||
|
||||
def _maybe_allocate_hashless_block_id(self, seq_id: Optional[int] = None) -> Optional[BlockId]:
|
||||
try:
|
||||
# Allocate mutable block and extract its block_id
|
||||
block = self._hashless_allocator.allocate_mutable_block(
|
||||
prev_block=None, seq_id=seq_id)
|
||||
block_id = block.block_id
|
||||
self._block_pool.free_block(block)
|
||||
|
||||
self._track_block_id(block_id, computed=False)
|
||||
return block_id
|
||||
except BlockAllocator.NoFreeBlocksError:
|
||||
return None
|
||||
|
||||
def _maybe_allocate_evicted_block_id(self) -> Optional[BlockId]:
|
||||
if self.evictor.num_blocks == 0:
|
||||
return None
|
||||
|
||||
# Here we get an evicted block, which is only added
|
||||
# into evictor if its ref counter is 0
|
||||
# and since its content would be changed, we need
|
||||
# to remove it from _cached_blocks's tracking list
|
||||
block_id, content_hash_to_evict = self.evictor.evict()
|
||||
|
||||
# Sanity checks
|
||||
assert content_hash_to_evict in self._cached_blocks
|
||||
_block_id = self._cached_blocks[content_hash_to_evict]
|
||||
assert self._refcounter.get(_block_id) == 0
|
||||
assert _block_id == block_id
|
||||
|
||||
self._cached_blocks.pop(content_hash_to_evict)
|
||||
|
||||
self._refcounter.incr(block_id)
|
||||
self._track_block_id(block_id, computed=False)
|
||||
|
||||
return block_id
|
||||
|
||||
def _free_block_id(self, block: Block, seq_id: Optional[int] = None) -> None:
|
||||
"""Decrements the refcount of the block. The block may be in two
|
||||
possible states: (1) immutable/cached or (2) mutable/hashless.
|
||||
In the first case, the refcount is decremented directly and the block
|
||||
may be possibly added to the evictor. In other case, hashless
|
||||
allocator free(..) with keep_block_object=True is called to only free
|
||||
the block id (since the block object may be reused by the caller)
|
||||
"""
|
||||
block_id = block.block_id
|
||||
assert block_id is not None, "Freeing unallocated block is undefined"
|
||||
|
||||
if block.content_hash is not None:
|
||||
# Immutable: This type of block is always cached, and we want to
|
||||
# keep it in the evictor for future reuse
|
||||
self._decr_refcount_cached_block(block)
|
||||
else:
|
||||
# Mutable: This type of block is not cached, so we release it
|
||||
# directly to the hashless allocator
|
||||
self._decr_refcount_hashless_block(block, seq_id=seq_id)
|
||||
|
||||
assert block.block_id is None
|
||||
|
||||
def free(self, block: Block, keep_block_object: bool = False, seq_id: Optional[int] = None) -> None:
|
||||
"""Release the block (look at free_block_id(..) docs)
|
||||
"""
|
||||
# Release the physical block index
|
||||
self._free_block_id(block, seq_id=seq_id)
|
||||
|
||||
# Release the block object to the pool
|
||||
if not keep_block_object:
|
||||
self._block_pool.free_block(block)
|
||||
|
||||
def fork(self, last_block: Block) -> List[Block]:
|
||||
"""Creates a new sequence of blocks that shares the same underlying
|
||||
memory as the original sequence.
|
||||
|
||||
Args:
|
||||
last_block (Block): The last block in the original sequence.
|
||||
|
||||
Returns:
|
||||
List[Block]: The new sequence of blocks that shares the same memory
|
||||
as the original sequence.
|
||||
"""
|
||||
source_blocks = get_all_blocks_recursively(last_block)
|
||||
|
||||
forked_blocks: List[Block] = []
|
||||
prev_block = None
|
||||
for block in source_blocks:
|
||||
block_id = block.block_id
|
||||
assert block_id is not None
|
||||
|
||||
refcount = self._refcounter.incr(block_id)
|
||||
assert refcount != 1, "can't fork free'd block_id = {}".format(
|
||||
block_id)
|
||||
|
||||
forked_block = self._block_pool.init_block(
|
||||
prev_block=prev_block,
|
||||
token_ids=block.token_ids,
|
||||
block_size=self._block_size,
|
||||
physical_block_id=block_id,
|
||||
extra_hash=block.extra_hash)
|
||||
|
||||
forked_blocks.append(forked_block)
|
||||
prev_block = forked_blocks[-1]
|
||||
|
||||
return forked_blocks
|
||||
|
||||
def get_num_free_blocks(self, seq_id: Optional[int] = None, device: Optional[Device] = None) -> int:
|
||||
assert device is None
|
||||
# The number of free blocks is the number of hashless free blocks
|
||||
# plus the number of blocks evictor could free from its list.
|
||||
return self._hashless_allocator.get_num_free_blocks(seq_id=seq_id
|
||||
) + self.evictor.num_blocks
|
||||
|
||||
def get_num_total_blocks(self, seq_id: Optional[int] = None) -> int:
|
||||
return self._hashless_allocator.get_num_total_blocks(seq_id=seq_id)
|
||||
|
||||
def get_physical_block_id(self, absolute_id: int) -> int:
|
||||
"""Returns the zero-offset block id on certain block allocator
|
||||
given the absolute block id.
|
||||
|
||||
Args:
|
||||
absolute_id (int): The absolute block id for the block
|
||||
in whole allocator.
|
||||
|
||||
Returns:
|
||||
int: The zero-offset block id on a certain device.
|
||||
"""
|
||||
return sorted(self.all_block_ids).index(absolute_id)
|
||||
|
||||
@property
|
||||
def all_block_ids(self) -> FrozenSet[int]:
|
||||
return self._hashless_allocator.all_block_ids
|
||||
|
||||
def get_prefix_cache_hit_rate(self) -> float:
|
||||
return self.metric_data.get_hit_rate()
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
"""Reset prefix cache. This function may be used in RLHF
|
||||
flows to invalidate prefix caching after the weights are updated,
|
||||
or used for resetting prefix caching status for benchmarking.
|
||||
|
||||
Returns:
|
||||
bool: True if the prefix cache is successfully reset,
|
||||
False otherwise.
|
||||
"""
|
||||
num_used_blocks = (self.get_num_total_blocks() -
|
||||
self.get_num_free_blocks())
|
||||
if num_used_blocks > 0:
|
||||
logger.warning(
|
||||
"Failed to reset prefix cache because some "
|
||||
"blocks (%d) are not freed yet", num_used_blocks)
|
||||
return False
|
||||
|
||||
# Free all blocks in the evictor.
|
||||
while (block_id :=
|
||||
self._maybe_allocate_evicted_block_id()) is not None:
|
||||
# TODO: Teng
|
||||
self._hashless_allocator.free_block_id(block_id)
|
||||
|
||||
# Should not have any cached blocks because all blocks are evicted.
|
||||
assert not self._cached_blocks
|
||||
|
||||
# Reset the evictor.
|
||||
self.evictor = make_evictor(self.eviction_policy)
|
||||
|
||||
# Reset the block tracker.
|
||||
for block_id in self._block_tracker:
|
||||
self._block_tracker[block_id] = BlockTracker()
|
||||
|
||||
# Reset the metrics.
|
||||
self.metric_data = CacheMetricData()
|
||||
|
||||
logger.info("Successfully reset prefix cache")
|
||||
return True
|
||||
|
||||
def is_block_cached(self, block: Block) -> bool:
|
||||
assert block.content_hash is not None
|
||||
return block.content_hash in self._cached_blocks
|
||||
|
||||
def promote_to_immutable_block(self, block: Block, seq_id: Optional[int] = None) -> BlockId:
|
||||
"""Once a mutable block is full, it can be promoted to an immutable
|
||||
block. This means that its content can be referenced by future blocks
|
||||
having the same prefix.
|
||||
|
||||
Note that if we already have a cached block with the same content, we
|
||||
will replace the newly-promoted block's mapping with the existing cached
|
||||
block id.
|
||||
|
||||
Args:
|
||||
block: The mutable block to be promoted.
|
||||
|
||||
Returns:
|
||||
BlockId: Either the original block index, or the block index of
|
||||
the previously cached block matching the same content.
|
||||
"""
|
||||
# Ensure block can be promoted
|
||||
assert block.content_hash is not None
|
||||
assert block.block_id is not None
|
||||
assert self._refcounter.get(block.block_id) > 0
|
||||
|
||||
if block.content_hash not in self._cached_blocks:
|
||||
# No cached content hash => Set this block as cached.
|
||||
# Note that this block cannot be marked as computed yet
|
||||
# because other sequences in the same batch cannot reuse
|
||||
# this block.
|
||||
self._cached_blocks[block.content_hash] = block.block_id
|
||||
# Mark this block as touched so that it can be marked as
|
||||
# computed after the entire batch of sequences are scheduled.
|
||||
self._touched_blocks.add(block.block_id)
|
||||
return block.block_id
|
||||
|
||||
# Reuse the cached content hash
|
||||
self._decr_refcount_hashless_block(block, seq_id=seq_id)
|
||||
block.block_id = self._cached_blocks[block.content_hash]
|
||||
|
||||
# Increment refcount of the cached block and (possibly) restore
|
||||
# it from the evictor.
|
||||
# Note that in this case, the block is marked as computed
|
||||
self._incr_refcount_cached_block(block)
|
||||
|
||||
return block.block_id
|
||||
|
||||
def cow_block_if_not_appendable(self, block: Block, seq_id: Optional[int] = None) -> BlockId:
|
||||
"""Performs a copy-on-write operation on the given block if it is not
|
||||
appendable.
|
||||
|
||||
Args:
|
||||
block (Block): The block to check for copy-on-write.
|
||||
|
||||
Returns:
|
||||
BlockId: The block index of the new block if a copy-on-write
|
||||
operation was performed, or the original block index if
|
||||
no copy-on-write was necessary.
|
||||
"""
|
||||
src_block_id = block.block_id
|
||||
assert src_block_id is not None
|
||||
|
||||
if self._cow_tracker.is_appendable(block):
|
||||
return src_block_id
|
||||
|
||||
self._free_block_id(block)
|
||||
trg_block_id = self._allocate_block_id()
|
||||
|
||||
self._cow_tracker.record_cow(src_block_id, trg_block_id)
|
||||
|
||||
return trg_block_id
|
||||
|
||||
def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]:
|
||||
"""Returns the copy-on-write source->destination mapping and clears it.
|
||||
|
||||
Returns:
|
||||
List[Tuple[BlockId, BlockId]]: A list mapping source
|
||||
block indices to destination block indices.
|
||||
"""
|
||||
return self._cow_tracker.clear_cows()
|
||||
|
||||
def mark_blocks_as_accessed(self, block_ids: List[int],
|
||||
now: float) -> None:
|
||||
"""Mark blocks as accessed, used in prefix caching.
|
||||
|
||||
If the block is added into evictor, we need to update corresponding
|
||||
info in evictor's metadata.
|
||||
"""
|
||||
|
||||
for block_id in block_ids:
|
||||
if self._block_tracker[block_id].active:
|
||||
self._block_tracker[block_id].last_accessed = now
|
||||
elif block_id in self.evictor:
|
||||
self.evictor.update(block_id, now)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Mark block as accessed which is not belonged to GPU")
|
||||
|
||||
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
|
||||
# Mark all touched blocks as computed.
|
||||
for block_id in self._touched_blocks:
|
||||
self._block_tracker[block_id].computed = True
|
||||
self._touched_blocks.clear()
|
||||
|
||||
def _track_block_id(self, block_id: Optional[BlockId],
|
||||
computed: bool) -> None:
|
||||
assert block_id is not None
|
||||
self._block_tracker[block_id].enable()
|
||||
self._block_tracker[block_id].computed = computed
|
||||
|
||||
def _untrack_block_id(self, block_id: Optional[BlockId]) -> None:
|
||||
assert block_id is not None
|
||||
self._block_tracker[block_id].disable()
|
||||
|
||||
def block_is_computed(self, block_id: int) -> bool:
|
||||
if self._block_tracker[block_id].active:
|
||||
return self._block_tracker[block_id].computed
|
||||
else:
|
||||
return block_id in self.evictor
|
||||
|
||||
def get_common_computed_block_ids(
|
||||
self, computed_seq_block_ids: List[List[int]]) -> List[int]:
|
||||
"""Return the block ids that are common for a given sequence group.
|
||||
|
||||
Only blocks that are immutable and have already been marked
|
||||
as computed are taken into consideration.
|
||||
"""
|
||||
|
||||
# NOTE We exclude the last block to avoid the case where the entire
|
||||
# prompt is cached. This would cause erroneous behavior in model
|
||||
# runner.
|
||||
|
||||
# It returns a list of int although type annotation says list of string.
|
||||
if len(computed_seq_block_ids) == 1:
|
||||
return computed_seq_block_ids[0]
|
||||
|
||||
return commonprefix([
|
||||
ids for ids in computed_seq_block_ids # type: ignore
|
||||
if ids
|
||||
])
|
||||
|
||||
def get_num_full_blocks_touched(self, blocks: List[Block]) -> int:
|
||||
"""Returns the number of full blocks that will be touched by
|
||||
swapping in/out.
|
||||
|
||||
Args:
|
||||
blocks: List of blocks to be swapped.
|
||||
Returns:
|
||||
int: the number of full blocks that will be touched by
|
||||
swapping in/out the given blocks. Non full blocks are ignored
|
||||
when deciding the number of blocks to touch.
|
||||
"""
|
||||
num_touched_blocks: int = 0
|
||||
for block in blocks:
|
||||
# If the block has a match in the cache and the cached
|
||||
# block is not referenced, then we still count it as a
|
||||
# touched block
|
||||
if block.is_full and (not self.is_block_cached(block) or \
|
||||
(block.content_hash is not None and \
|
||||
self._cached_blocks[block.content_hash] in \
|
||||
self.evictor)):
|
||||
num_touched_blocks += 1
|
||||
return num_touched_blocks
|
||||
|
||||
def swap_out(self, blocks: List[Block]) -> None:
|
||||
"""Execute the swap out actions. Basically just free the
|
||||
given blocks.
|
||||
|
||||
Args:
|
||||
blocks: List of blocks to be swapped out.
|
||||
"""
|
||||
for block in blocks:
|
||||
self._free_block_id(block)
|
||||
|
||||
def swap_in(self, blocks: List[Block]) -> None:
|
||||
"""Execute the swap in actions. Change the block id from
|
||||
old allocator to current allocator for each block to finish
|
||||
the block table update.
|
||||
|
||||
Args:
|
||||
blocks: List of blocks to be swapped in.
|
||||
"""
|
||||
for block in blocks:
|
||||
# Here we allocate either immutable or mutable block and then
|
||||
# extract its block_id. Note that the block object is released
|
||||
# and the block_id is assigned to "block" to allow reusing the
|
||||
# existing "block" object
|
||||
if block.is_full:
|
||||
tmp_block = self.allocate_immutable_block(
|
||||
prev_block=block.prev_block,
|
||||
token_ids=block.token_ids,
|
||||
extra_hash=block.extra_hash)
|
||||
else:
|
||||
tmp_block = self.allocate_mutable_block(
|
||||
prev_block=block.prev_block, extra_hash=block.extra_hash)
|
||||
tmp_block.append_token_ids(block.token_ids)
|
||||
|
||||
block_id = tmp_block.block_id
|
||||
self._block_pool.free_block(tmp_block)
|
||||
|
||||
block.block_id = block_id # Assign block_id
|
||||
|
||||
def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
|
||||
"""
|
||||
Given a list of block hashes, return the prefix of the block hashes that
|
||||
are all cached.
|
||||
|
||||
Since a block's block hash includes the hashes of all previous blocks,
|
||||
and we only allocate/deallocate blocks in the entire sequence, so if a
|
||||
block is cached, then all previous blocks are also cached. With this
|
||||
property, we can use binary search to find the prefix of cached blocks.
|
||||
|
||||
Args:
|
||||
block_hashes (List[int]): The list of block hashes.
|
||||
|
||||
Returns:
|
||||
List[int]: The prefix of the `block_hashes` that are cached.
|
||||
"""
|
||||
|
||||
def _block_is_cached(block_hash: PrefixHash) -> bool:
|
||||
if block_hash not in self._cached_blocks:
|
||||
return False
|
||||
|
||||
cached_block_id = self._cached_blocks[block_hash]
|
||||
# We only consider the blocks that are marked as computed.
|
||||
return self.block_is_computed(cached_block_id)
|
||||
|
||||
def _bisect_left(a, x, key: Callable[[PrefixHash], bool]) -> int:
|
||||
|
||||
# Python < 3.10 does not have the key argument
|
||||
if sys.version_info < (3, 10):
|
||||
a = [key(e) for e in a]
|
||||
return bisect_left(a, x)
|
||||
else:
|
||||
return bisect_left(a, x, key=key)
|
||||
|
||||
# Look for the first block that's not cached, and returns the prefix
|
||||
# i.e. blocks that are cached.
|
||||
idx = _bisect_left(block_hashes,
|
||||
True,
|
||||
key=lambda x: not _block_is_cached(x))
|
||||
return block_hashes[:idx]
|
||||
|
||||
|
||||
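# Minimal sketch (illustration only, not used by the allocator) of why
# bisect_left applies in find_cached_blocks_prefix: a block is cached only
# if all of its predecessors are cached, so "not cached" is monotone along
# the hash list. The cache contents below are hypothetical, and the key=
# argument requires Python 3.10+.
def _demo_find_cached_prefix() -> List[int]:
    cached = {11, 22}                   # hypothetical cached block hashes
    block_hashes = [11, 22, 33, 44]     # hashes of one sequence, in order
    idx = bisect_left(block_hashes, True, key=lambda h: h not in cached)
    return block_hashes[:idx]           # -> [11, 22]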
class PrefixCachingBlock(Block):
|
||||
"""A block implementation that supports prefix caching.
|
||||
|
||||
The PrefixCachingBlock class represents a block of token IDs with prefix
|
||||
caching capabilities. It wraps a NaiveBlock internally and provides
|
||||
additional functionality for content hashing and promoting immutable blocks
|
||||
with the prefix caching allocator.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[PrefixCachingBlock]): The previous block in the
|
||||
sequence.
|
||||
token_ids (List[int]): The initial token IDs to be stored in the block.
|
||||
block_size (int): The maximum number of token IDs that can be stored in
|
||||
the block.
|
||||
allocator (BlockAllocator): The prefix
|
||||
caching block allocator associated with this block.
|
||||
block_id (Optional[int], optional): The physical block index
|
||||
of this block. Defaults to None.
|
||||
extra_hash (Optional[int]): The hash value of additional factors
|
||||
such as adapters that influence the block, apart from the token_ids.
|
||||
"""
|
||||
|
||||
# Note that we use 'None' as a string here instead of None because
|
||||
# as of Python 3.12, hash(None) returns a constant predictable value.
|
||||
# This could possibly make it easier to find and exploit hash
|
||||
# collisions. 'None' as a string will be hashed differently per process,
|
||||
# but consistently within the same process. This is the same as the
|
||||
# behavior of None prior to Python 3.12.
|
||||
_none_hash: int = hash('None')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prev_block: Optional[Block],
|
||||
token_ids: List[int],
|
||||
block_size: int,
|
||||
allocator: BlockAllocator,
|
||||
block_id: Optional[int] = None,
|
||||
computed: bool = False,
|
||||
extra_hash: Optional[int] = None,
|
||||
):
|
||||
assert isinstance(allocator, PrefixCachingBlockAllocator), (
|
||||
"Currently this class is only tested with "
|
||||
"PrefixCachingBlockAllocator. Got instead allocator = {}".format(
|
||||
allocator))
|
||||
assert_prefix_caching_block_or_none(prev_block)
|
||||
|
||||
self._prev_block = prev_block
|
||||
self._cached_content_hash: Optional[int] = None
|
||||
self._cached_num_tokens_total: int = 0
|
||||
self._allocator = allocator
|
||||
self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
|
||||
self._computed = computed
|
||||
self._extra_hash = extra_hash
|
||||
|
||||
# On the first time, we create the block object, and next we only
|
||||
# reinitialize it
|
||||
if hasattr(self, "_block"):
|
||||
self._block.__init__( # type: ignore[has-type]
|
||||
prev_block=prev_block,
|
||||
token_ids=token_ids,
|
||||
block_size=block_size,
|
||||
block_id=block_id,
|
||||
allocator=self._allocator)
|
||||
else:
|
||||
self._block = NaiveBlock(prev_block=prev_block,
|
||||
token_ids=token_ids,
|
||||
block_size=block_size,
|
||||
block_id=block_id,
|
||||
allocator=self._allocator)
|
||||
|
||||
self._update_num_tokens_total()
|
||||
|
||||
def _update_num_tokens_total(self):
|
||||
"""Incrementally computes the number of tokens that there is
|
||||
and including the current block.
|
||||
"""
|
||||
res = 0
|
||||
|
||||
# Add all previous blocks
|
||||
if self._prev_block is not None:
|
||||
res += self._prev_block.num_tokens_total
|
||||
|
||||
# Add current block
|
||||
res += len(self.token_ids)
|
||||
|
||||
self._cached_num_tokens_total = res
|
||||
|
||||
@property
|
||||
def computed(self) -> bool:
|
||||
return self._computed
|
||||
|
||||
@computed.setter
|
||||
def computed(self, value) -> None:
|
||||
self._computed = value
|
||||
|
||||
@property
|
||||
def last_accessed(self) -> float:
|
||||
return self._last_accessed
|
||||
|
||||
@last_accessed.setter
|
||||
def last_accessed(self, last_accessed_ts: float):
|
||||
self._last_accessed = last_accessed_ts
|
||||
|
||||
def append_token_ids(self, token_ids: List[int], seq_id: Optional[int] = None) -> None:
|
||||
"""Appends the given token IDs to the block and registers the block as
|
||||
immutable if the block becomes full.
|
||||
|
||||
Args:
|
||||
token_ids (List[int]): The token IDs to be appended to the block.
|
||||
"""
|
||||
# Ensure this is mutable block (not promoted)
|
||||
assert self.content_hash is None
|
||||
assert not self.computed
|
||||
assert seq_id is not None
|
||||
|
||||
if len(token_ids) == 0:
|
||||
return
|
||||
|
||||
# Ensure there are input tokens
|
||||
assert token_ids, "Got token_ids = {}".format(token_ids)
|
||||
|
||||
# Naive block handles CoW.
|
||||
self._block.append_token_ids(token_ids, seq_id=seq_id)
|
||||
self._update_num_tokens_total()
|
||||
|
||||
# If the content hash is present, then the block can be made immutable.
|
||||
# Register ourselves with the allocator, potentially replacing the
|
||||
# physical block index.
|
||||
if self.content_hash is not None:
|
||||
self.block_id = self._allocator.promote_to_immutable_block(self, seq_id=seq_id)
|
||||
|
||||
@property
|
||||
def block_id(self) -> Optional[int]:
|
||||
return self._block.block_id
|
||||
|
||||
@block_id.setter
|
||||
def block_id(self, value) -> None:
|
||||
self._block.block_id = value
|
||||
|
||||
@property
|
||||
def is_full(self) -> bool:
|
||||
return self._block.is_full
|
||||
|
||||
@property
|
||||
def num_empty_slots(self) -> int:
|
||||
return self._block.num_empty_slots
|
||||
|
||||
@property
|
||||
def num_tokens_total(self) -> int:
|
||||
return self._cached_num_tokens_total
|
||||
|
||||
@property
|
||||
def block_size(self) -> int:
|
||||
return self._block.block_size
|
||||
|
||||
@property
|
||||
def token_ids(self) -> List[int]:
|
||||
return self._block.token_ids
|
||||
|
||||
@property
|
||||
def prev_block(self) -> Optional[Block]:
|
||||
return self._prev_block
|
||||
|
||||
@property
|
||||
def extra_hash(self) -> Optional[int]:
|
||||
return self._extra_hash
|
||||
|
||||
@property
|
||||
def content_hash(self) -> Optional[int]:
|
||||
"""Return the content-based hash of the current block, or None if it is
|
||||
not yet defined.
|
||||
|
||||
For the content-based hash to be defined, the current block must be
|
||||
full.
|
||||
"""
|
||||
# If the hash is already computed, return it.
|
||||
if self._cached_content_hash is not None:
|
||||
return self._cached_content_hash
|
||||
|
||||
# We cannot compute a hash for the current block because it is not full.
|
||||
if not self.is_full:
|
||||
return None
|
||||
|
||||
is_first_block = self._prev_block is None
|
||||
prev_block_hash = (
|
||||
self._none_hash if is_first_block else
|
||||
self._prev_block.content_hash # type: ignore
|
||||
)
|
||||
|
||||
# Previous block exists but does not yet have a hash.
|
||||
# Return no hash in this case.
|
||||
if prev_block_hash == self._none_hash and not is_first_block:
|
||||
return None
|
||||
|
||||
self._cached_content_hash = PrefixCachingBlock.hash_block_tokens(
|
||||
is_first_block,
|
||||
prev_block_hash,
|
||||
cur_block_token_ids=self.token_ids,
|
||||
extra_hash=self._extra_hash)
|
||||
return self._cached_content_hash
|
||||
|
||||
@classmethod
|
||||
def hash_block_tokens(cls,
|
||||
is_first_block: bool,
|
||||
prev_block_hash: Optional[int],
|
||||
cur_block_token_ids: List[int],
|
||||
extra_hash: Optional[int] = None) -> int:
|
||||
"""Computes a hash value corresponding to the contents of a block and
|
||||
the contents of the preceding block(s). The hash value is used for
|
||||
prefix caching.
|
||||
|
||||
Parameters:
|
||||
- is_first_block (bool): A flag indicating if the block is the first in
|
||||
the sequence.
|
||||
- prev_block_hash (Optional[int]): The hash of the previous block. None
|
||||
if this is the first block.
|
||||
- cur_block_token_ids (List[int]): A list of token ids in the current
|
||||
block. The current block is assumed to be full.
|
||||
- extra_hash (Optional[int]): The hash value of additional factors
|
||||
such as adapters that influence the block, apart from the token_ids.
|
||||
|
||||
Returns:
|
||||
- int: The computed hash value for the block.
|
||||
"""
|
||||
if is_first_block and prev_block_hash is None:
|
||||
prev_block_hash = cls._none_hash
|
||||
return hash((is_first_block, prev_block_hash, *cur_block_token_ids,
|
||||
extra_hash))
|
||||
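# Hedged sketch (illustration only, not used by the allocator): content
# hashes chain through the sequence, so two sequences that share the same
# full-block prefix produce identical hashes for those blocks and can reuse
# the same physical blocks. The token values and block size are made up.
def _demo_chained_block_hashes(block_size: int = 4) -> Tuple[int, int]:
    first = PrefixCachingBlock.hash_block_tokens(
        is_first_block=True,
        prev_block_hash=PrefixCachingBlock._none_hash,
        cur_block_token_ids=list(range(block_size)))
    second = PrefixCachingBlock.hash_block_tokens(
        is_first_block=False,
        prev_block_hash=first,
        cur_block_token_ids=list(range(block_size, 2 * block_size)))
    return first, second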
575
vllm_vacc/vllm/core/block_manager.py
Normal file
@@ -0,0 +1,575 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""A block manager that manages token blocks."""
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Tuple
|
||||
|
||||
from vllm.core.block.block_table import BlockTable
|
||||
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
|
||||
from vllm.core.block.interfaces import Block
|
||||
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
|
||||
LastAccessBlocksTracker)
|
||||
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
|
||||
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
|
||||
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
|
||||
from vllm.utils import Device
|
||||
import os
|
||||
SeqId = int
|
||||
EncoderSeqId = str
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from vllm_vacc.vllm.model_executor.models.vars import LLM_MAX_PREFILL_SEQ_LEN
|
||||
|
||||
max_seq_num = int(os.getenv("MAX_SEQ_NUM", 4))
|
||||
if max_seq_num not in [1, 2, 4]:
|
||||
max_seq_num = 4
|
||||
|
||||
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
||||
|
||||
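# Example of the per-sequence block budget implied by MAX_SEQ_NUM
# (hypothetical numbers; the real values come from the engine config):
#   num_gpu_blocks = 4096, MAX_SEQ_NUM = 4
#   per_gpu_blocks = 4096 // 4 = 1024 blocks per sequence
# and self.watermark_blocks below is derived from that per-sequence budget
# (watermark * per_gpu_blocks).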
class SelfAttnBlockSpaceManager(BlockSpaceManager):
|
||||
"""BlockSpaceManager which manages the allocation of KV cache.
|
||||
|
||||
It owns responsibility for allocation, swapping, allocating memory for
|
||||
autoregressively-generated tokens, and other advanced features such as
|
||||
prefix caching, forking/copy-on-write, and sliding-window memory allocation.
|
||||
|
||||
This class implements the design described in
|
||||
https://github.com/vllm-project/vllm/pull/3492.
|
||||
|
||||
Lookahead slots
|
||||
The block manager has the notion of a "lookahead slot". These are slots
|
||||
in the KV cache that are allocated for a sequence. Unlike the other
|
||||
allocated slots, the content of these slots is undefined -- the worker
|
||||
may use the memory allocations in any way.
|
||||
|
||||
In practice, a worker could use these lookahead slots to run multiple
|
||||
forward passes for a single scheduler invocation. Each successive
|
||||
forward pass would write KV activations to the corresponding lookahead
|
||||
slot. This allows low inter-token latency use-cases, where the overhead
|
||||
of continuous batching scheduling is amortized over >1 generated tokens.
|
||||
|
||||
Speculative decoding uses lookahead slots to store KV activations of
|
||||
proposal tokens.
|
||||
|
||||
See https://github.com/vllm-project/vllm/pull/3250 for more information
|
||||
on lookahead scheduling.
|
||||
|
||||
Args:
|
||||
block_size (int): The size of each memory block.
|
||||
num_gpu_blocks (int): The number of memory blocks allocated on GPU.
|
||||
num_cpu_blocks (int): The number of memory blocks allocated on CPU.
|
||||
watermark (float, optional): The threshold used for memory swapping.
|
||||
Defaults to 0.01.
|
||||
sliding_window (Optional[int], optional): The size of the sliding
|
||||
window. Defaults to None.
|
||||
enable_caching (bool, optional): Flag indicating whether caching is
|
||||
enabled. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_size: int,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
watermark: float = 0.01,
|
||||
sliding_window: Optional[int] = None,
|
||||
enable_caching: bool = False,
|
||||
) -> None:
|
||||
self.block_size = block_size
|
||||
self.num_total_gpu_blocks = num_gpu_blocks
|
||||
self.num_total_cpu_blocks = num_cpu_blocks
|
||||
self.per_gpu_blocks = num_gpu_blocks // max_seq_num
|
||||
|
||||
self.sliding_window = sliding_window
|
||||
# max_block_sliding_window is the max number of blocks that need to be
|
||||
# allocated
|
||||
self.max_block_sliding_window = None
|
||||
if sliding_window is not None:
|
||||
# +1 here because // rounds down
|
||||
num_blocks = sliding_window // block_size + 1
|
||||
# +1 here because the last block may not be full,
|
||||
# and so the sequence stretches one more block at the beginning
|
||||
# For example, if sliding_window is 3 and block_size is 4,
|
||||
# we may need 2 blocks when the second block only holds 1 token.
|
||||
self.max_block_sliding_window = num_blocks + 1
|
||||
|
||||
self.watermark = watermark
|
||||
assert watermark >= 0.0
|
||||
|
||||
self.enable_caching = enable_caching
|
||||
|
||||
# self.watermark_blocks = 1 # for test
|
||||
self.watermark_blocks = int(watermark * self.per_gpu_blocks)
|
||||
|
||||
self.block_allocator = CpuGpuBlockAllocator.create(
|
||||
allocator_type="prefix_caching" if enable_caching else "naive",
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
block_size=block_size,
|
||||
)
|
||||
|
||||
self.block_tables: Dict[SeqId, BlockTable] = {}
|
||||
self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {}
|
||||
|
||||
self._computed_blocks_tracker = ComputedBlocksTracker(
|
||||
self.block_allocator, self.block_size, self.enable_caching)
|
||||
self._last_access_blocks_tracker = LastAccessBlocksTracker(
|
||||
self.block_allocator)
|
||||
|
||||
def can_allocate(self,
|
||||
seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int = 0) -> AllocStatus:
|
||||
# FIXME(woosuk): Here we assume that all sequences in the group share
|
||||
# the same prompt. This may not be true for preempted sequences.
|
||||
|
||||
check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
|
||||
|
||||
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
|
||||
num_required_blocks = BlockTable.get_num_required_blocks(
|
||||
seq.get_token_ids(),
|
||||
block_size=self.block_size,
|
||||
num_lookahead_slots=num_lookahead_slots,
|
||||
)
|
||||
|
||||
if seq_group.is_encoder_decoder():
|
||||
encoder_seq = seq_group.get_encoder_seq()
|
||||
assert encoder_seq is not None
|
||||
num_required_blocks += BlockTable.get_num_required_blocks(
|
||||
encoder_seq.get_token_ids(),
|
||||
block_size=self.block_size,
|
||||
)
|
||||
|
||||
if self.max_block_sliding_window is not None:
|
||||
num_required_blocks = min(num_required_blocks,
|
||||
self.max_block_sliding_window)
|
||||
|
||||
# TODO
|
||||
# limitations to be removed later
|
||||
required_size = num_required_blocks * self.block_size
|
||||
if required_size > LLM_MAX_PREFILL_SEQ_LEN:
|
||||
logging.warning(
|
||||
f"This model's maximum input seq length limit is "
|
||||
f"{LLM_MAX_PREFILL_SEQ_LEN} tokens. However, you requested "
|
||||
f"({required_size} in the input messages, "
|
||||
f"Please reduce the length of the input messages.")
|
||||
return AllocStatus.NEVER
|
||||
|
||||
# Use watermark to avoid frequent cache eviction.
|
||||
# NOTE:
|
||||
# num of the gpu blocks for each seq_id might not be the same
|
||||
# since each sequence can use a different number of block groups
|
||||
|
||||
total_gpu_blocks = self.block_allocator.get_num_total_blocks(device=Device.GPU, seq_id=seq.seq_id)
|
||||
if (total_gpu_blocks
|
||||
- num_required_blocks < self.watermark_blocks):
|
||||
|
||||
return AllocStatus.NEVER
|
||||
|
||||
# NOTE: num_required_blocks should be aligned up to the 8K group size before the comparison
|
||||
block_num = env_blk_grp_size // self.block_size
|
||||
if num_required_blocks % block_num: # align
|
||||
num_required_blocks = (num_required_blocks // block_num + 1) * block_num
|
||||
|
||||
if total_gpu_blocks % block_num:
|
||||
total_gpu_blocks = (total_gpu_blocks // block_num) * block_num
|
||||
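# Worked example (illustrative numbers only): with block_size = 16 and an
# 8K block group, block_num = 8192 // 16 = 512; a request needing 600
# blocks rounds up to 1024, while 1000 total GPU blocks round down to 512
# before the watermark comparison below.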
# Use the aligned memory size to decide whether to reject the request.
|
||||
if total_gpu_blocks - num_required_blocks < self.watermark_blocks:
|
||||
logging.warning("gpu memory may not enough, please try shorter sequence")
|
||||
return AllocStatus.NEVER
|
||||
|
||||
num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
|
||||
device=Device.GPU, seq_id=seq.seq_id)
|
||||
# logging.warning(f"free blocks: {num_free_gpu_blocks} required: {num_required_blocks} watermark: {self.watermark_blocks}")
|
||||
|
||||
if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
|
||||
return AllocStatus.OK
|
||||
else:
|
||||
# logging.warning(f"free_blocks:{num_free_gpu_blocks} "
|
||||
# f"required_blocks:{num_required_blocks} "
|
||||
# f"watermark:{self.watermark_blocks} "
|
||||
# f"allocate seq:{seq.seq_id} later")
|
||||
return AllocStatus.LATER
|
||||
|
||||
|
||||
def _allocate_sequence(self, seq: Sequence) -> BlockTable:
|
||||
block_table = BlockTable(
|
||||
block_size=self.block_size,
|
||||
block_allocator=self.block_allocator,
|
||||
max_block_sliding_window=self.max_block_sliding_window,
|
||||
)
|
||||
if seq.get_token_ids():
|
||||
# NOTE: If there are any factors affecting the block besides
|
||||
# token_ids, they should be added as input to extra_hash.
|
||||
extra_hash = seq.extra_hash()
|
||||
|
||||
# Add blocks to the block table only if the sequence is non empty.
|
||||
block_table.allocate(token_ids=seq.get_token_ids(),
|
||||
extra_hash=extra_hash,
|
||||
seq_id=seq.seq_id)
|
||||
|
||||
return block_table
|
||||
|
||||
def allocate(self, seq_group: SequenceGroup) -> None:
|
||||
|
||||
# Allocate self-attention block tables for decoder sequences
|
||||
waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
|
||||
assert not (set(seq.seq_id for seq in waiting_seqs)
|
||||
& self.block_tables.keys()), "block table already exists"
|
||||
|
||||
# NOTE: Here we assume that all sequences in the group have the same
|
||||
# prompt.
|
||||
seq = waiting_seqs[0]
|
||||
block_table: BlockTable = self._allocate_sequence(seq)
|
||||
self.block_tables[seq.seq_id] = block_table
|
||||
|
||||
# Track seq
|
||||
self._last_access_blocks_tracker.add_seq(seq.seq_id)
|
||||
|
||||
# Assign the block table for each sequence.
|
||||
for seq in waiting_seqs[1:]:
|
||||
self.block_tables[seq.seq_id] = block_table.fork()
|
||||
|
||||
# Track seq
|
||||
self._last_access_blocks_tracker.add_seq(seq.seq_id)
|
||||
|
||||
# Allocate cross-attention block table for encoder sequence
|
||||
#
|
||||
# NOTE: Here we assume that all sequences in the group have the same
|
||||
# encoder prompt.
|
||||
request_id = seq_group.request_id
|
||||
|
||||
assert (request_id
|
||||
not in self.cross_block_tables), \
|
||||
"block table already exists"
|
||||
|
||||
check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
|
||||
|
||||
if seq_group.is_encoder_decoder():
|
||||
encoder_seq = seq_group.get_encoder_seq()
|
||||
assert encoder_seq is not None
|
||||
block_table = self._allocate_sequence(encoder_seq)
|
||||
self.cross_block_tables[request_id] = block_table
|
||||
|
||||
def can_append_slots(self, seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int) -> bool:
|
||||
"""Determine if there is enough space in the GPU KV cache to continue
|
||||
generation of the specified sequence group.
|
||||
|
||||
We use a worst-case heuristic: assume each touched block will require a
|
||||
new allocation (either via CoW or new block). We can append slots if the
|
||||
number of touched blocks is less than the number of free blocks.
|
||||
|
||||
"Lookahead slots" are slots that are allocated in addition to the slots
|
||||
for known tokens. The contents of the lookahead slots are not defined.
|
||||
This is used by speculative decoding when speculating future tokens.
|
||||
"""
|
||||
|
||||
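        # Example of the worst-case heuristic (hypothetical numbers): if the
        # RUNNING sequences would touch 3 blocks in total but only 2 blocks
        # are free in the relevant groups, this returns False and the
        # scheduler may preempt (recompute) the sequence group.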
num_touched_blocks = 0
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
block_table = self.block_tables[seq.seq_id]
|
||||
|
||||
num_touched_blocks += (
|
||||
block_table.get_num_blocks_touched_by_append_slots(
|
||||
token_ids=block_table.get_unseen_token_ids(
|
||||
seq.get_token_ids()),
|
||||
num_lookahead_slots=num_lookahead_slots,
|
||||
))
|
||||
|
||||
num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
|
||||
Device.GPU, seq_id=seq.seq_id)
|
||||
# NOTE: if False, trigger RECOMPUTE
|
||||
return num_touched_blocks <= num_free_gpu_blocks
|
||||
|
||||
def append_slots(
|
||||
self,
|
||||
seq: Sequence,
|
||||
num_lookahead_slots: int,
|
||||
) -> List[Tuple[int, int]]:
|
||||
|
||||
block_table = self.block_tables[seq.seq_id]
|
||||
|
||||
block_table.append_token_ids(
|
||||
token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()),
|
||||
num_lookahead_slots=num_lookahead_slots,
|
||||
num_computed_slots=seq.data.get_num_computed_tokens(),
|
||||
extra_hash=seq.extra_hash(),
|
||||
seq_id = seq.seq_id
|
||||
)
|
||||
# Return any new copy-on-writes.
|
||||
new_cows = self.block_allocator.clear_copy_on_writes(seq.seq_id)
|
||||
return new_cows
|
||||
|
||||
def free(self, seq: Sequence) -> None:
|
||||
seq_id = seq.seq_id
|
||||
|
||||
if seq_id not in self.block_tables:
|
||||
# Already freed or hasn't been scheduled yet.
|
||||
return
|
||||
|
||||
# Update seq block ids with the latest access time
|
||||
self._last_access_blocks_tracker.update_seq_blocks_last_access(
|
||||
seq_id, self.block_tables[seq.seq_id].physical_block_ids)
|
||||
|
||||
# Untrack seq
|
||||
self._last_access_blocks_tracker.remove_seq(seq_id)
|
||||
self._computed_blocks_tracker.remove_seq(seq_id)
|
||||
|
||||
# Free table/blocks
|
||||
self.block_tables[seq_id].free(seq_id)
|
||||
del self.block_tables[seq_id]
|
||||
|
||||
def remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None:
|
||||
seq_id = seq.seq_id
|
||||
self._computed_blocks_tracker.remove_seq(seq_id)
|
||||
|
||||
def free_cross(self, seq_group: SequenceGroup) -> None:
|
||||
request_id = seq_group.request_id
|
||||
if request_id not in self.cross_block_tables:
|
||||
# Already freed or hasn't been scheduled yet.
|
||||
return
|
||||
self.cross_block_tables[request_id].free()
|
||||
del self.cross_block_tables[request_id]
|
||||
|
||||
def get_block_table(self, seq: Sequence) -> List[int]:
|
||||
block_ids = self.block_tables[seq.seq_id].physical_block_ids
|
||||
return block_ids # type: ignore
|
||||
|
||||
def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
|
||||
request_id = seq_group.request_id
|
||||
assert request_id in self.cross_block_tables
|
||||
block_ids = self.cross_block_tables[request_id].physical_block_ids
|
||||
assert all(b is not None for b in block_ids)
|
||||
return block_ids # type: ignore
|
||||
|
||||
def access_all_blocks_in_seq(self, seq: Sequence, now: float):
|
||||
if self.enable_caching:
|
||||
# Record the latest access time for the sequence. The actual update
|
||||
# of the block ids is deferred to the sequence free(..) call, since
|
||||
# only during freeing of block ids, the blocks are actually added to
|
||||
# the evictor (which is when the most updated time is required)
|
||||
# (This avoids expensive calls to mark_blocks_as_accessed(..))
|
||||
self._last_access_blocks_tracker.update_last_access(
|
||||
seq.seq_id, now)
|
||||
|
||||
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
|
||||
token_chunk_size: int):
|
||||
# If prefix caching is enabled, mark immutable blocks as computed
|
||||
# right after they have been scheduled (for prefill). This assumes
|
||||
# the scheduler is synchronous so blocks are actually computed when
|
||||
# scheduling the next batch.
|
||||
self.block_allocator.mark_blocks_as_computed([])
|
||||
|
||||
def get_common_computed_block_ids(
|
||||
self, seqs: List[Sequence]) -> GenericSequence[int]:
|
||||
"""Determine which blocks for which we skip prefill.
|
||||
|
||||
With prefix caching we can skip prefill for previously-generated blocks.
|
||||
Currently, the attention implementation only supports skipping cached
|
||||
blocks if they are a contiguous prefix of cached blocks.
|
||||
|
||||
This method determines which blocks can be safely skipped for all
|
||||
sequences in the sequence group.
|
||||
"""
|
||||
computed_seq_block_ids = []
|
||||
for seq in seqs:
|
||||
all_blocks = self.block_tables[seq.seq_id].physical_block_ids
|
||||
num_cached_tokens = (
|
||||
self._computed_blocks_tracker.get_num_cached_tokens(seq))
|
||||
assert num_cached_tokens % self.block_size == 0
|
||||
num_cached_blocks = num_cached_tokens // self.block_size
|
||||
computed_block_ids = all_blocks[:num_cached_blocks]
|
||||
computed_seq_block_ids.append(computed_block_ids)
|
||||
|
||||
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
|
||||
return self.block_allocator.get_common_computed_block_ids(
|
||||
computed_seq_block_ids) # type: ignore
|
||||
|
||||
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
|
||||
if parent_seq.seq_id not in self.block_tables:
|
||||
# Parent sequence has either been freed or never existed.
|
||||
return
|
||||
src_block_table = self.block_tables[parent_seq.seq_id]
|
||||
self.block_tables[child_seq.seq_id] = src_block_table.fork()
|
||||
|
||||
# Track child seq
|
||||
self._last_access_blocks_tracker.add_seq(child_seq.seq_id)
|
||||
|
||||
def can_swap_in(self, seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int) -> AllocStatus:
|
||||
"""Returns the AllocStatus for the given sequence_group
|
||||
with num_lookahead_slots.
|
||||
|
||||
Args:
|
||||
sequence_group (SequenceGroup): The sequence group to swap in.
|
||||
num_lookahead_slots (int): Number of lookahead slots used in
|
||||
speculative decoding, default to 0.
|
||||
|
||||
Returns:
|
||||
AllocStatus: The AllocStatus for the given sequence group.
|
||||
"""
|
||||
return self._can_swap(seq_group, Device.GPU, SequenceStatus.SWAPPED,
|
||||
num_lookahead_slots)
|
||||
|
||||
def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
|
||||
"""Returns the block id mapping (from CPU to GPU) generated by
|
||||
swapping in the given seq_group.
|
||||
|
||||
Args:
|
||||
seq_group (SequenceGroup): The sequence group to swap in.
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, int]]: The mapping of swapping block from CPU
|
||||
to GPU.
|
||||
"""
|
||||
physical_block_id_mapping = []
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
|
||||
blocks = self.block_tables[seq.seq_id].blocks
|
||||
if len(blocks) == 0:
|
||||
continue
|
||||
|
||||
seq_swap_mapping = self.block_allocator.swap(blocks=blocks,
|
||||
src_device=Device.CPU,
|
||||
dst_device=Device.GPU)
|
||||
|
||||
# Refresh the block ids of the table (post-swap)
|
||||
self.block_tables[seq.seq_id].update(blocks)
|
||||
|
||||
seq_physical_block_id_mapping = {
|
||||
self.block_allocator.get_physical_block_id(
|
||||
Device.CPU, cpu_block_id):
|
||||
self.block_allocator.get_physical_block_id(
|
||||
Device.GPU, gpu_block_id)
|
||||
for cpu_block_id, gpu_block_id in seq_swap_mapping.items()
|
||||
}
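# Each entry maps the source CPU physical block id to the destination GPU
# physical block id, so the caller can issue the corresponding KV-cache copies.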
|
||||
|
||||
physical_block_id_mapping.extend(
|
||||
list(seq_physical_block_id_mapping.items()))
|
||||
|
||||
return physical_block_id_mapping
|
||||
|
||||
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
|
||||
"""Returns whether we can swap out the given sequence_group
|
||||
with num_lookahead_slots.
|
||||
|
||||
Args:
|
||||
seq_group (SequenceGroup): The sequence group to swap out.
|
||||
|
||||
|
||||
Returns:
|
||||
bool: Whether it's possible to swap out current sequence group.
|
||||
"""
|
||||
alloc_status = self._can_swap(seq_group, Device.CPU,
|
||||
SequenceStatus.RUNNING)
|
||||
return alloc_status == AllocStatus.OK
|
||||
|
||||
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
|
||||
"""Returns the block id mapping (from GPU to CPU) generated by
|
||||
swapping out the given sequence_group.
|
||||
|
||||
Args:
|
||||
sequence_group (SequenceGroup): The sequence group to swap out.
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, int]]: The mapping of swapping block from
|
||||
GPU to CPU.
|
||||
"""
|
||||
physical_block_id_mapping = []
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
blocks = self.block_tables[seq.seq_id].blocks
|
||||
if len(blocks) == 0:
|
||||
continue
|
||||
|
||||
seq_swap_mapping = self.block_allocator.swap(blocks=blocks,
|
||||
src_device=Device.GPU,
|
||||
dst_device=Device.CPU)
|
||||
|
||||
# Refresh the block ids of the table (post-swap)
|
||||
self.block_tables[seq.seq_id].update(blocks)
|
||||
|
||||
seq_physical_block_id_mapping = {
|
||||
self.block_allocator.get_physical_block_id(
|
||||
Device.GPU, gpu_block_id):
|
||||
self.block_allocator.get_physical_block_id(
|
||||
Device.CPU, cpu_block_id)
|
||||
for gpu_block_id, cpu_block_id in seq_swap_mapping.items()
|
||||
}
|
||||
|
||||
physical_block_id_mapping.extend(
|
||||
list(seq_physical_block_id_mapping.items()))
|
||||
|
||||
return physical_block_id_mapping
|
||||
|
||||
def get_num_free_gpu_blocks(self) -> int:
|
||||
return self.block_allocator.get_num_free_blocks(Device.GPU)
|
||||
|
||||
def get_num_free_cpu_blocks(self) -> int:
|
||||
return self.block_allocator.get_num_free_blocks(Device.CPU)
|
||||
|
||||
def get_prefix_cache_hit_rate(self, device: Device) -> float:
|
||||
return self.block_allocator.get_prefix_cache_hit_rate(device)
|
||||
|
||||
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
|
||||
return self.block_allocator.reset_prefix_cache(device)
|
||||
|
||||
def _can_swap(self,
|
||||
seq_group: SequenceGroup,
|
||||
device: Device,
|
||||
status: SequenceStatus,
|
||||
num_lookahead_slots: int = 0) -> AllocStatus:
|
||||
"""Returns the AllocStatus for swapping in/out the given sequence_group
|
||||
on to the 'device'.
|
||||
|
||||
Args:
|
||||
sequence_group (SequenceGroup): The sequence group to swap in/out.
|
||||
device (Device): device to swap the 'seq_group' on.
|
||||
status (SequenceStatus): The status of sequence which is needed
|
||||
for action. RUNNING for swap out and SWAPPED for swap in
|
||||
num_lookahead_slots (int): Number of lookahead slots used in
|
||||
speculative decoding, default to 0.
|
||||
|
||||
Returns:
|
||||
AllocStatus: The AllocStatus for swapping in/out the given
|
||||
sequence_group on to the 'device'.
|
||||
"""
|
||||
# First determine the number of blocks that will be touched by this
|
||||
# swap. Then verify if there are available blocks in the device
|
||||
# to perform the swap.
|
||||
num_blocks_touched = 0
|
||||
blocks: List[Block] = []
|
||||
for seq in seq_group.get_seqs(status=status):
|
||||
block_table = self.block_tables[seq.seq_id]
|
||||
if block_table.blocks is not None:
|
||||
# Compute the number of blocks to touch for the tokens to be
|
||||
# appended. This does NOT include the full blocks that need
|
||||
# to be touched for the swap.
|
||||
num_blocks_touched += \
|
||||
block_table.get_num_blocks_touched_by_append_slots(
|
||||
block_table.get_unseen_token_ids(seq.get_token_ids()),
|
||||
num_lookahead_slots=num_lookahead_slots)
|
||||
blocks.extend(block_table.blocks)
|
||||
# Compute the number of full blocks to touch and add it to the
|
||||
# existing count of blocks to touch.
|
||||
num_blocks_touched += self.block_allocator.get_num_full_blocks_touched(
|
||||
blocks, device=device)
|
||||
|
||||
watermark_blocks = 0
|
||||
if device == Device.GPU:
|
||||
watermark_blocks = self.watermark_blocks
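# For swap-in to GPU, keep `watermark_blocks` free blocks as a safety margin;
# the checks below compare the free-block count minus the touched-block
# estimate against this margin. Swapping out to CPU uses no such margin.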
|
||||
|
||||
if self.block_allocator.get_num_total_blocks(
|
||||
device) < num_blocks_touched:
|
||||
return AllocStatus.NEVER
|
||||
elif self.block_allocator.get_num_free_blocks(
|
||||
device) - num_blocks_touched >= watermark_blocks:
|
||||
return AllocStatus.OK
|
||||
else:
|
||||
return AllocStatus.LATER
|
||||
|
||||
def get_num_cached_tokens(self, seq: Sequence) -> int:
|
||||
"""Get the number of tokens in blocks that are already computed and
|
||||
cached in the block manager for the sequence.
|
||||
"""
|
||||
return self._computed_blocks_tracker.get_num_cached_tokens(seq)
|
||||
0
vllm_vacc/vllm/distributed/__init__.py
Normal file
0
vllm_vacc/vllm/distributed/__init__.py
Normal file
BIN
vllm_vacc/vllm/distributed/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/distributed/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
23
vllm_vacc/vllm/distributed/communication_op.py
Normal file
23
vllm_vacc/vllm/distributed/communication_op.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
def tensor_model_parallel_all_reduce_with_odsp(input_: torch.Tensor) -> torch.Tensor:
|
||||
"""All-reduce the input tensor across model parallel group."""
|
||||
from vllm.distributed import get_tp_group
|
||||
try:
|
||||
total_bytes = input_.numel() * input_.element_size() * get_tp_group().world_size
|
||||
# The DSP custom all_reduce currently only supports payloads under 4 MiB (4194304 bytes).
|
||||
if total_bytes < 4194304:
|
||||
from torch_vacc.vacc import all_reduce
|
||||
return all_reduce(input_,
|
||||
get_tp_group().rank_in_group,
|
||||
get_tp_group().world_size,
|
||||
get_tp_group().group_id,
|
||||
dev_info = get_tp_group().rank_device_infos)
|
||||
except Exception as e:
|
||||
print("all_reduce by DSP run Fail, now use vccl-ops", e, input_.shape, input_.dtype)
|
||||
|
||||
return get_tp_group().all_reduce(input_)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
def all_gather_into_tensor(self, input_: torch.Tensor, dim: int = -1, output_tensor: torch.Tensor = None) -> torch.Tensor:
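# Concat-style all-gather that can write into a caller-provided buffer: when
# `output_tensor` is given, the gathered result is placed there so the same
# allocation can be reused across calls instead of allocating a new tensor.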
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
input_size = input_.size()
|
||||
# NOTE: we have to use concat-style all-gather here,
|
||||
# stack-style all-gather has compatibility issues with
|
||||
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
|
||||
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
|
||||
# Allocate output tensor.
|
||||
# [N, ...] => [N * world_size, ...]
|
||||
if output_tensor is None:
|
||||
output_tensor = torch.empty(output_size,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device)
|
||||
# print("o tensor is:", output_tensor.shape, "i tensor is:", input_.shape, input_size)
|
||||
# All-gather.
|
||||
dist.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
return output_tensor
|
||||
485
vllm_vacc/vllm/distributed/parallel_state.py
Normal file
485
vllm_vacc/vllm/distributed/parallel_state.py
Normal file
@@ -0,0 +1,485 @@
|
||||
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import supports_custom_op
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import (Any, Dict, List, Optional, Tuple,
|
||||
Union)
|
||||
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.distributed.parallel_state import TensorMetadata
|
||||
|
||||
# memory recycler
|
||||
MEMORY_RECYCLER_KEY = ['previous_hidden_states']
|
||||
|
||||
def _split_tensor_dict_concat(
|
||||
tensor_dict: Dict[str, Union[torch.Tensor, Any]]
|
||||
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
|
||||
"""Split the tensor dictionary into two parts:
|
||||
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
|
||||
by its metadata.
|
||||
2. A list of tensors.
|
||||
"""
|
||||
|
||||
# all_tensor_list = ['input_tokens','input_positions', 'slot_mapping','seq_lens_tensor', 'context_lens_tensor','block_tables', 'query_start_loc','seq_start_loc', 'selected_token_indices']
|
||||
metadata_list: List[Tuple[str, Any]] = []
|
||||
tensor_list: List[torch.Tensor] = []
|
||||
all_tensor = []
|
||||
all_tensor_numel = 0
|
||||
for key, value in tensor_dict.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
# Note: we cannot use `value.device` here,
|
||||
# because it contains not only the device type but also the device
|
||||
# index (e.g. "cuda:0"). We only need the device type.
|
||||
# receiving side will set the device index.
|
||||
device = value.device.type
|
||||
|
||||
if not value.is_cpu and value.numel() > 0:
|
||||
value_bytes_tensor = value.view(torch.int8)
|
||||
all_tensor.append(value_bytes_tensor.view([-1]))
|
||||
all_tensor_numel += value_bytes_tensor.numel()
|
||||
# tensor_list.append(value)
|
||||
|
||||
metadata_list.append(
|
||||
(key, TensorMetadata(device, value.dtype, value.size())))
|
||||
|
||||
else:
|
||||
metadata_list.append((key, value))
|
||||
if len(all_tensor) != 0:
|
||||
|
||||
memory_recycler_dynamic_output = None
|
||||
# Compute the total size of all_tensor
|
||||
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler, DeepseekMTPMemoryRecycler
|
||||
if isinstance(memory_recycler, DeepseekMTPMemoryRecycler):
|
||||
memory_recycler_dynamic_output = memory_recycler.DYNAMIC_OUTPUT_BUFFER.view(torch.int8)[:all_tensor_numel]
|
||||
|
||||
if memory_recycler_dynamic_output is not None:
|
||||
all_tensor = torch.concatenate(all_tensor, 0, out = memory_recycler_dynamic_output)
|
||||
else:
|
||||
all_tensor = torch.concatenate(all_tensor, 0)
|
||||
|
||||
tensor_list.append(all_tensor)
|
||||
metadata_list.append(("all_tensor", TensorMetadata(all_tensor.device.type, all_tensor.dtype, all_tensor.size())))
|
||||
|
||||
return metadata_list, tensor_list
|
||||
|
||||
def all_gather_to_rank0(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
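# Gather `input_` from every rank in the group, but keep the result only on
# rank 0; all other ranks return None. A DSP custom all_gather is tried for
# payloads under 4 MiB, with torch.distributed as the fallback.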
|
||||
world_size = self.world_size
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
|
||||
# For TPUs, use TPU communicator.
|
||||
tpu_comm = self.tpu_communicator
|
||||
if tpu_comm is not None and not tpu_comm.disabled:
|
||||
return tpu_comm.all_gather(input_, dim)
|
||||
|
||||
# For HPUs, use HPU communicator.
|
||||
hpu_comm = self.hpu_communicator
|
||||
if hpu_comm is not None and not hpu_comm.disabled:
|
||||
return hpu_comm.all_gather(input_, dim)
|
||||
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
input_size = input_.size()
|
||||
# NOTE: we have to use concat-style all-gather here,
|
||||
# stack-style all-gather has compatibility issues with
|
||||
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
|
||||
output_size = (input_size[0] * world_size, ) + input_size[1:]
|
||||
|
||||
try:
|
||||
total_bytes = input_.numel() * input_.element_size() * world_size
|
||||
# The DSP custom all_gather currently only supports payloads under 4 MiB (4194304 bytes).
|
||||
if total_bytes < 4194304:
|
||||
from torch_vacc.vacc.custom_ops import all_gather
|
||||
output_tensor = all_gather(input_, self.rank_in_group, self.world_size, self.group_id,
|
||||
dev_info = self.rank_device_infos)
|
||||
|
||||
if self.rank_in_group != 0:
|
||||
output_tensor = None
|
||||
else:
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(input_size[:dim] +
|
||||
(world_size *
|
||||
input_size[dim], ) +
|
||||
input_size[dim + 1:])
|
||||
|
||||
return output_tensor
|
||||
except Exception as e:
|
||||
print("all_gather by DSP run Fail, now use vccl-ops", e, input_.shape, input_.dtype)
|
||||
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(output_size,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device)
|
||||
# All-gather.
|
||||
torch.distributed.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
if self.rank_in_group != 0:
|
||||
output_tensor = None
|
||||
else:
|
||||
# Reshape
|
||||
output_tensor = output_tensor.reshape((world_size, ) + input_size)
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(input_size[:dim] +
|
||||
(world_size *
|
||||
input_size[dim], ) +
|
||||
input_size[dim + 1:])
|
||||
return output_tensor
|
||||
|
||||
def generate_group_id(self, group_id):
|
||||
self.group_id = group_id
|
||||
|
||||
def generate_rank_device_infos(self):
|
||||
import numpy as np
|
||||
import os
|
||||
# encoder rank_dev_list
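# Each entry packs the logical rank into the upper 16 bits and the physical
# device id into the lower 16 bits of a uint32, e.g. rank 1 on device 3 ->
# (1 << 16) | 3 = 65539; uncombine_array reverses the packing.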
|
||||
def combine_arrays(a, b):
|
||||
a = np.asarray(a, dtype=np.uint32)
|
||||
b = np.asarray(b, dtype=np.uint32)
|
||||
|
||||
if len(a) != len(b):
|
||||
raise ValueError("两个数组的长度必须一致。")
|
||||
|
||||
a_shifted = np.left_shift(a, 16)
|
||||
combined = np.bitwise_or(a_shifted, b)
|
||||
return combined.tolist()
|
||||
|
||||
# decoder rank_dev_list
|
||||
def uncombine_array(array):
|
||||
array = np.asarray(array, dtype=np.uint32)
|
||||
o_0 = array >> 16
|
||||
o_1 = array << 16 >> 16
|
||||
return o_0, o_1
|
||||
|
||||
physical_devices = self.ranks
|
||||
visible_devices = os.getenv('VACC_VISIBLE_DEVICES')
|
||||
|
||||
if visible_devices is not None:
|
||||
device_list = visible_devices.split(',')
|
||||
device_count = len(device_list)
|
||||
assert device_count >= len(self.ranks), f'VACC_VISIBLE_DEVICES:{device_count} is less than ranks:{len(self.ranks)}, please designate more devices'
|
||||
physical_devices = [int(device_list[i]) for i in self.ranks]
|
||||
# print("[vccl] logic_devices:physical_devices ", self.ranks, physical_devices)
|
||||
|
||||
logic_ranks = [self.ranks.index(rank) for rank in self.ranks]
|
||||
self.rank_device_infos = combine_arrays(logic_ranks, physical_devices)
|
||||
|
||||
def get_bitwidth(dtype):
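# Bit width of a torch dtype, e.g. torch.float16 -> 16, torch.int32 -> 32;
# used below to convert element counts into byte offsets when slicing the
# concatenated broadcast buffer.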
|
||||
if dtype.is_floating_point:
|
||||
return torch.finfo(dtype).bits
|
||||
else:
|
||||
return torch.iinfo(dtype).bits
|
||||
|
||||
class GroupCoordinator:
|
||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
User-facing all-reduce function before we actually call the
|
||||
all-reduce operation.
|
||||
|
||||
We need this because Dynamo does not support passing an arbitrary
|
||||
object (`self` in this case) to a custom op. We need to pass the
|
||||
group name as a string, and then look up the group coordinator from
|
||||
the group name, dispatch the all-reduce operation to the group
|
||||
coordinator.
|
||||
|
||||
In addition, PyTorch custom ops do not support mutation or returning
|
||||
a new tensor in the same op. So we need to figure out if the op is
|
||||
in-place or out-of-place ahead of time.
|
||||
"""
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if self.world_size == 1:
|
||||
return input_
|
||||
if input_.is_cpu:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
ipex.distributed.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
# vacc impl
|
||||
# s0, s1 = input_.shape
|
||||
# output_tensor = torch.empty([32, s0, s1],
|
||||
# dtype=input_.dtype,
|
||||
# device=input_.device)
|
||||
# torch.distributed.all_gather_into_tensor(output_tensor,
|
||||
# input_,
|
||||
# group=self.device_group)
|
||||
# input_ = output_tensor.sum(dim=0)
|
||||
torch.distributed.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
def broadcast_tensor_dict(
|
||||
self,
|
||||
tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
|
||||
src: int = 0,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
metadata_group: Optional[ProcessGroup] = None
|
||||
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
|
||||
"""Broadcast the input tensor dictionary.
|
||||
NOTE: `src` is the local rank of the source rank.
|
||||
"""
|
||||
|
||||
# all_tensor_list = ['input_tokens','input_positions', 'slot_mapping','seq_lens_tensor', 'context_lens_tensor','block_tables', 'query_start_loc','seq_start_loc', 'selected_token_indices']
|
||||
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if (not torch.distributed.is_initialized() or self.world_size == 1):
|
||||
return tensor_dict
|
||||
|
||||
group = self.device_group
|
||||
metadata_group = self.cpu_group
|
||||
assert src < self.world_size, f"Invalid src rank ({src})"
|
||||
|
||||
rank_in_group = self.rank_in_group
|
||||
|
||||
if rank_in_group == src:
|
||||
metadata_list: List[Tuple[Any, Any]] = []
|
||||
assert isinstance(
|
||||
tensor_dict,
|
||||
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
|
||||
# metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
|
||||
metadata_list, tensor_list = _split_tensor_dict_concat(tensor_dict)
|
||||
# `metadata_list` lives in CPU memory.
|
||||
# `broadcast_object_list` has serialization & deserialization,
|
||||
# all happening on CPU. Therefore, we can use the CPU group.
|
||||
# metadata_list holds (key, value) pairs; for tensors, value is a TensorMetadata carrying only dtype/shape, not the data.
|
||||
self.broadcast_object(metadata_list, src=src)
|
||||
async_handles = []
|
||||
for tensor in tensor_list:
|
||||
if tensor.numel() == 0:
|
||||
# Skip broadcasting empty tensors.
|
||||
continue
|
||||
if tensor.is_cpu:
|
||||
# use metadata_group for CPU tensors
|
||||
handle = torch.distributed.broadcast(tensor,
|
||||
src=self.ranks[src],
|
||||
group=metadata_group,
|
||||
async_op=True)
|
||||
async_handles.append(handle)
|
||||
else:
|
||||
# use group for GPU tensors
|
||||
total_bytes = tensor.numel() * tensor.element_size()
|
||||
use_dist = True
|
||||
# The ODSP custom broadcast currently only supports payloads under 4 MiB (4194304 bytes).
|
||||
if total_bytes < 4194304:
|
||||
try:
|
||||
from torch_vacc.vacc.custom_ops import broadcast
|
||||
#print("send tensor is:", tensor.shape, tensor.dtype, self.rank)
|
||||
broadcast(tensor, self.rank_in_group, self.world_size, root_rank=0, group_id=self.group_id,
|
||||
dev_info = self.rank_device_infos)
|
||||
use_dist = False
|
||||
except Exception as e:
|
||||
print("odsp broadcast run fail, now using distributed:", e)
|
||||
|
||||
if use_dist:
|
||||
handle = torch.distributed.broadcast(tensor,
|
||||
src=self.ranks[src],
|
||||
group=group,
|
||||
async_op=True)
|
||||
async_handles.append(handle)
|
||||
|
||||
for async_handle in async_handles:
|
||||
async_handle.wait()
|
||||
|
||||
else:
|
||||
# other rank
|
||||
metadata_list = self.broadcast_object(None, src=src)
|
||||
tensor_dict = {}
|
||||
async_handles = []
|
||||
tensor_size = [] # list of [key, shape] for split all_tensor
|
||||
dataType_list = []
|
||||
for key, value in metadata_list:
|
||||
# if rank_in_group == 1:
|
||||
# print('rank1 k v ', key, value)
|
||||
if isinstance(value, TensorMetadata):
|
||||
tensor = None
|
||||
|
||||
# The concatenated buffer is always int8 (raw bytes).
|
||||
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler, DeepseekMTPMemoryRecycler
|
||||
if isinstance(memory_recycler, DeepseekMTPMemoryRecycler) and value.dtype == torch.int8:
|
||||
tensor = memory_recycler.DYNAMIC_OUTPUT_BUFFER.view(value.dtype)[:value.size.numel()].view(value.size)
|
||||
|
||||
if tensor is None:
|
||||
tensor = torch.empty(value.size,
|
||||
dtype=value.dtype,
|
||||
device=value.device)
|
||||
if tensor.numel() == 0:
|
||||
# Skip broadcasting empty tensors.
|
||||
tensor_dict[key] = tensor
|
||||
continue
|
||||
if tensor.is_cpu:
|
||||
# use metadata_group for CPU tensors
|
||||
handle = torch.distributed.broadcast(
|
||||
tensor,
|
||||
src=self.ranks[src],
|
||||
group=metadata_group,
|
||||
async_op=True)
|
||||
async_handles.append(handle)
|
||||
tensor_dict[key] = tensor
|
||||
else:
|
||||
# use group for GPU tensors
|
||||
if key == "all_tensor":
|
||||
total_bytes = tensor.numel() * tensor.element_size()
|
||||
use_dist = True
|
||||
# The DSP custom broadcast currently only supports payloads under 4 MiB (4194304 bytes).
|
||||
if total_bytes < 4194304:
|
||||
try:
|
||||
from torch_vacc.vacc.custom_ops import broadcast
|
||||
tensor = broadcast(tensor, self.rank_in_group, self.world_size, root_rank=0, group_id=self.group_id,
|
||||
dev_info = self.rank_device_infos)
|
||||
use_dist = False
|
||||
except Exception as e:
|
||||
print("dsp brocast run fail, now using distributed:", e)
|
||||
|
||||
if use_dist:
|
||||
handle = torch.distributed.broadcast(
|
||||
tensor, # the concatenated int8 buffer
|
||||
src=self.ranks[src],
|
||||
group=group,
|
||||
async_op=False)
|
||||
|
||||
# Split all_tensor back into per-key tensors using the recorded (key, shape) pairs and store them in tensor_dict.
|
||||
start = 0
|
||||
idx = 0
|
||||
for ki_vi in tensor_size:
|
||||
ki, vi = ki_vi
|
||||
length = vi.numel() * int(get_bitwidth(dataType_list[idx]) / 8)
|
||||
if ki in MEMORY_RECYCLER_KEY:
|
||||
tensor_dict[ki] = tensor[start:start+length].view(dataType_list[idx]).view(vi)
|
||||
else:
|
||||
value_tensor = torch.empty(vi,
|
||||
dtype=dataType_list[idx],
|
||||
device=value.device)
|
||||
recv_tensor = tensor[start:start+length].view(dataType_list[idx]).view(vi)
|
||||
tensor_dict[ki] = value_tensor.copy_(recv_tensor)
|
||||
start += length
|
||||
idx += 1
|
||||
else:
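# Non-"all_tensor" GPU entries were folded into the concatenated buffer on the
# sender side; only record their dtype and shape here and reconstruct them once
# "all_tensor" (always the last TensorMetadata entry) has been received.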
|
||||
dataType_list.append(value.dtype)
|
||||
tensor_size.append([key, value.size])
|
||||
else:
|
||||
tensor_dict[key] = value
|
||||
for async_handle in async_handles:
|
||||
async_handle.wait()
|
||||
return tensor_dict
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1, output_: torch.Tensor = None) -> torch.Tensor:
|
||||
world_size = self.world_size
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
|
||||
if self.use_custom_op_call:
|
||||
return torch.ops.vllm.all_gather(input_,
|
||||
dim,
|
||||
world_size,
|
||||
group_name=self.unique_name)
|
||||
else:
|
||||
# Use the output-reusing variant of all_gather when an output buffer is provided.
|
||||
if output_ is not None:
|
||||
return self.device_communicator.all_gather_into_tensor(input_, dim, output_)
|
||||
return self._all_gather_out_place(input_, dim)
|
||||
|
||||
def recv_tensor_dict(
|
||||
self,
|
||||
src: Optional[int] = None,
|
||||
all_gather_group: Optional["GroupCoordinator"] = None,
|
||||
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
|
||||
"""Recv the input tensor dictionary.
|
||||
NOTE: `src` is the local rank of the source rank.
|
||||
"""
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if not torch.distributed.is_initialized() or self.world_size == 1:
|
||||
return None
|
||||
all_gather_size = (1 if all_gather_group is None else
|
||||
all_gather_group.world_size)
|
||||
all_gather_rank = (0 if all_gather_group is None else
|
||||
all_gather_group.rank_in_group)
|
||||
|
||||
group = self.device_group
|
||||
metadata_group = self.cpu_group
|
||||
|
||||
if src is None:
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
assert src < self.world_size, f"Invalid src rank ({src})"
|
||||
|
||||
recv_metadata_list = self.recv_object(src=src)
|
||||
tensor_dict: dict[str, Any] = {}
|
||||
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import alloc_pipeline_parallel_recycler_buffer
|
||||
memory_recycler_list = ["hidden_states", "residual"]
|
||||
for key, value in recv_metadata_list:
|
||||
if isinstance(value, TensorMetadata):
|
||||
# Decide whether this tensor's memory can be reused:
# 1. the key is in ["hidden_states", "residual"], i.e. it comes from pipeline parallelism
# 2. a buffer for this key can be obtained from the memory_recycling module
|
||||
use_create_recycler_tensor = False
|
||||
tensor = None
|
||||
value_tensor = None # buffer that receives the full all_gather output
|
||||
if key in memory_recycler_list:
|
||||
tensor = alloc_pipeline_parallel_recycler_buffer(value.size, value.dtype, key)
|
||||
if tensor is not None:
|
||||
use_create_recycler_tensor = True
|
||||
|
||||
if not use_create_recycler_tensor:
|
||||
tensor = torch.empty(value.size,
|
||||
dtype=value.dtype,
|
||||
device=value.device)
|
||||
|
||||
value_tensor = tensor
|
||||
|
||||
if tensor.numel() == 0:
|
||||
# Skip broadcasting empty tensors.
|
||||
tensor_dict[key] = tensor
|
||||
continue
|
||||
|
||||
# send-allgather: send only a slice, then do allgather.
|
||||
use_all_gather = (all_gather_group is not None
|
||||
and tensor.numel() % all_gather_size == 0)
|
||||
|
||||
if use_all_gather:
|
||||
orig_shape = tensor.shape
|
||||
# With the recycled buffer a view is enough; no reshape/copy is needed.
|
||||
if use_create_recycler_tensor:
|
||||
tensor = tensor.view(all_gather_size,
|
||||
-1)[all_gather_rank].contiguous()
|
||||
else:
|
||||
tensor = tensor.reshape(all_gather_size,
|
||||
-1)[all_gather_rank]
|
||||
|
||||
if tensor.is_cpu:
|
||||
# use metadata_group for CPU tensors
|
||||
torch.distributed.recv(tensor,
|
||||
src=self.ranks[src],
|
||||
group=metadata_group)
|
||||
else:
|
||||
# use group for GPU tensors
|
||||
torch.distributed.recv(tensor,
|
||||
src=self.ranks[src],
|
||||
group=group)
|
||||
if use_all_gather:
|
||||
# do the allgather
|
||||
if use_create_recycler_tensor:
|
||||
tensor = all_gather_group.all_gather( # type: ignore
|
||||
tensor, dim=0, output_ = value_tensor)
|
||||
tensor = tensor.view(orig_shape)
|
||||
else:
|
||||
tensor = all_gather_group.all_gather( # type: ignore
|
||||
tensor, dim=0)
|
||||
tensor = tensor.reshape(orig_shape)
|
||||
|
||||
tensor_dict[key] = tensor
|
||||
else:
|
||||
tensor_dict[key] = value
|
||||
return tensor_dict
|
||||
0
vllm_vacc/vllm/engine/__init__.py
Normal file
0
vllm_vacc/vllm/engine/__init__.py
Normal file
BIN
vllm_vacc/vllm/engine/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/engine/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/engine/__pycache__/arg_utils.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/engine/__pycache__/arg_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/engine/__pycache__/llm_engine.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/engine/__pycache__/llm_engine.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/engine/__pycache__/metrics.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/engine/__pycache__/metrics.cpython-312.pyc
Normal file
Binary file not shown.
158
vllm_vacc/vllm/engine/arg_utils.py
Normal file
158
vllm_vacc/vllm/engine/arg_utils.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import argparse
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
import json
|
||||
import sys
|
||||
import threading
|
||||
import warnings
|
||||
from dataclasses import MISSING, dataclass, fields, is_dataclass
|
||||
from itertools import permutations
|
||||
from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
|
||||
Type, TypeVar, Union, cast, get_args, get_origin)
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from typing_extensions import TypeIs, deprecated
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.plugins import load_general_plugins
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
|
||||
from vllm.transformers_utils.utils import check_gguf_file
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import (FlexibleArgumentParser, GiB_bytes, get_ip,
|
||||
is_in_ray_actor)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
def _set_default_args(self, usage_context: UsageContext,
|
||||
model_config: ModelConfig) -> None:
|
||||
"""Set Default Arguments for V1 Engine."""
|
||||
|
||||
# Unlike the upstream V1 defaults, the vacc patch disables chunked prefill and prefix caching.
|
||||
self.enable_chunked_prefill = False
|
||||
self.enable_prefix_caching = False
|
||||
if model_config.runner_type != "pooling":
|
||||
# TODO: When prefix caching supports prompt embeds inputs, this
|
||||
# check can be removed.
|
||||
if (self.enable_prompt_embeds
|
||||
and self.enable_prefix_caching is not False):
|
||||
logger.warning(
|
||||
"--enable-prompt-embeds and --enable-prefix-caching "
|
||||
"are not supported together in V1. Prefix caching has "
|
||||
"been disabled.")
|
||||
|
||||
# V1 should use the new scheduler by default.
|
||||
# Swap it only if this arg is set to the original V0 default
|
||||
if self.scheduler_cls == EngineArgs.scheduler_cls:
|
||||
self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"
|
||||
|
||||
# When no user override, set the default values based on the usage
|
||||
# context.
|
||||
# Use different default values for different hardware.
|
||||
|
||||
# Try to query the device name on the current platform. If it fails,
|
||||
# it may be because the platform that imports vLLM is not the same
|
||||
# as the platform that vLLM is running on (e.g. the case of scaling
|
||||
# vLLM with Ray) and has no GPUs. In this case we use the default
|
||||
# values for non-H100/H200 GPUs.
|
||||
try:
|
||||
device_memory = current_platform.get_device_total_memory()
|
||||
device_name = current_platform.get_device_name().lower()
|
||||
except Exception:
|
||||
# This is only used to set default_max_num_batched_tokens
|
||||
device_memory = 0
|
||||
|
||||
# NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
|
||||
# throughput, see PR #17885 for more details.
|
||||
# So here we do an extra device name check to prevent such regression.
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
|
||||
# For GPUs like H100 and MI300x, use larger default values.
|
||||
default_max_num_batched_tokens = {
|
||||
UsageContext.LLM_CLASS: 16384,
|
||||
UsageContext.OPENAI_API_SERVER: 8192,
|
||||
}
|
||||
default_max_num_seqs = {
|
||||
UsageContext.LLM_CLASS: 1024,
|
||||
UsageContext.OPENAI_API_SERVER: 1024,
|
||||
}
|
||||
else:
|
||||
# TODO(woosuk): Tune the default values for other hardware.
|
||||
default_max_num_batched_tokens = {
|
||||
UsageContext.LLM_CLASS: 8192,
|
||||
UsageContext.OPENAI_API_SERVER: 2048,
|
||||
}
|
||||
default_max_num_seqs = {
|
||||
UsageContext.LLM_CLASS: 4,
|
||||
UsageContext.OPENAI_API_SERVER: 4,
|
||||
}
|
||||
|
||||
# tpu specific default values.
|
||||
if current_platform.is_tpu():
|
||||
default_max_num_batched_tokens_tpu = {
|
||||
UsageContext.LLM_CLASS: {
|
||||
'V6E': 2048,
|
||||
'V5E': 1024,
|
||||
'V5P': 512,
|
||||
},
|
||||
UsageContext.OPENAI_API_SERVER: {
|
||||
'V6E': 1024,
|
||||
'V5E': 512,
|
||||
'V5P': 256,
|
||||
}
|
||||
}
|
||||
|
||||
# cpu specific default values.
|
||||
if current_platform.is_cpu():
|
||||
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
|
||||
default_max_num_batched_tokens = {
|
||||
UsageContext.LLM_CLASS: 4096 * world_size,
|
||||
UsageContext.OPENAI_API_SERVER: 2048 * world_size,
|
||||
}
|
||||
default_max_num_seqs = {
|
||||
UsageContext.LLM_CLASS: 256 * world_size,
|
||||
UsageContext.OPENAI_API_SERVER: 128 * world_size,
|
||||
}
|
||||
|
||||
use_context_value = usage_context.value if usage_context else None
|
||||
if (self.max_num_batched_tokens is None
|
||||
and usage_context in default_max_num_batched_tokens):
|
||||
if current_platform.is_tpu():
|
||||
chip_name = current_platform.get_device_name()
|
||||
if chip_name in default_max_num_batched_tokens_tpu[
|
||||
usage_context]:
|
||||
self.max_num_batched_tokens = \
|
||||
default_max_num_batched_tokens_tpu[
|
||||
usage_context][chip_name]
|
||||
else:
|
||||
self.max_num_batched_tokens = \
|
||||
default_max_num_batched_tokens[usage_context]
|
||||
else:
|
||||
if not self.enable_chunked_prefill:
|
||||
self.max_num_batched_tokens = model_config.max_model_len
|
||||
else:
|
||||
self.max_num_batched_tokens = \
|
||||
default_max_num_batched_tokens[usage_context]
|
||||
logger.debug(
|
||||
"Setting max_num_batched_tokens to %d for %s usage context.",
|
||||
self.max_num_batched_tokens, use_context_value)
|
||||
|
||||
if (self.max_num_seqs is None
|
||||
and usage_context in default_max_num_seqs):
|
||||
self.max_num_seqs = min(default_max_num_seqs[usage_context],
|
||||
self.max_num_batched_tokens or sys.maxsize)
|
||||
|
||||
logger.debug("Setting max_num_seqs to %d for %s usage context.",
|
||||
self.max_num_seqs, use_context_value)
|
||||
|
||||
49
vllm_vacc/vllm/engine/llm_engine.py
Normal file
49
vllm_vacc/vllm/engine/llm_engine.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from typing import Dict, Optional
|
||||
from vllm.engine.metrics_types import StatLoggerBase
|
||||
|
||||
class LLMEngine:
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(
|
||||
cls,
|
||||
engine_args: EngineArgs,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
||||
) -> "LLMEngine":
|
||||
"""Creates an LLM engine from the engine arguments."""
|
||||
# Create the engine configs.
|
||||
vllm_config = engine_args.create_engine_config(usage_context)
|
||||
# Patch: only num_speculative_tokens == 1 is supported.
|
||||
speculative_mode = hasattr(vllm_config, 'speculative_config')
|
||||
if speculative_mode and \
|
||||
hasattr(vllm_config.speculative_config, 'num_speculative_tokens') and \
|
||||
vllm_config.speculative_config.num_speculative_tokens != 1:
|
||||
raise ValueError(f'run_mp_engine: only support num_speculative_tokens == 1, but get {vllm_config.speculative_config.num_speculative_tokens}')
|
||||
|
||||
default_model_infos = "default"
|
||||
if speculative_mode:
|
||||
if hasattr(vllm_config.speculative_config, 'method'):
|
||||
default_model_infos = vllm_config.speculative_config.method
|
||||
|
||||
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
|
||||
vllm_vacc_config_manager().update_model_infos(default_model_infos)
|
||||
|
||||
import vllm.envs as envs
|
||||
engine_cls = None
|
||||
if envs.VLLM_USE_V1:
|
||||
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
|
||||
engine_cls = V1LLMEngine
|
||||
else:
|
||||
from vllm.engine.llm_engine import LLMEngine as DefaultEngine
|
||||
engine_cls = DefaultEngine
|
||||
|
||||
assert engine_cls is not None, f"LLMEngine is empty: {engine_cls}"
|
||||
|
||||
return engine_cls.from_vllm_config(
|
||||
vllm_config=vllm_config,
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
disable_log_stats=engine_args.disable_log_stats,
|
||||
)
|
||||
69
vllm_vacc/vllm/engine/metrics.py
Normal file
69
vllm_vacc/vllm/engine/metrics.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from vllm.engine.metrics_types import (StatLoggerBase, Stats)
|
||||
import vllm_vacc.vllm.model_executor.models.vars as global_vars
|
||||
|
||||
class LoggingStatLogger(StatLoggerBase):
|
||||
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
|
||||
|
||||
def log(self, stats: Stats) -> None:
|
||||
from vllm.engine.metrics import local_interval_elapsed, get_throughput, logger
|
||||
"""Called by LLMEngine.
|
||||
Logs to Stdout every self.local_interval seconds."""
|
||||
|
||||
# Save tracked stats for token counters.
|
||||
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
|
||||
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
|
||||
|
||||
# Update spec decode metrics
|
||||
self.maybe_update_spec_decode_metrics(stats)
|
||||
|
||||
# Log locally every local_interval seconds.
|
||||
if local_interval_elapsed(stats.now, self.last_local_log,
|
||||
self.local_interval):
|
||||
# Compute summary metrics for tracked stats (and log them
|
||||
# to prometheus if applicable).
|
||||
prompt_throughput = get_throughput(self.num_prompt_tokens,
|
||||
now=stats.now,
|
||||
last_log=self.last_local_log)
|
||||
generation_throughput = get_throughput(
|
||||
self.num_generation_tokens,
|
||||
now=stats.now,
|
||||
last_log=self.last_local_log)
|
||||
|
||||
log_fn = logger.info
|
||||
if not any((prompt_throughput, generation_throughput,
|
||||
self.last_prompt_throughput,
|
||||
self.last_generation_throughput)):
|
||||
# Avoid log noise on an idle production system
|
||||
log_fn = logger.debug
|
||||
|
||||
log_fn(
|
||||
"Avg prompt throughput: %.1f tokens/s, "
|
||||
"Avg generation throughput: %.1f tokens/s, "
|
||||
"Running: %d reqs, Swapped: %d reqs, "
|
||||
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
|
||||
"CPU KV cache usage: %.1f%%., "
|
||||
"Do sequences length: %s",
|
||||
prompt_throughput,
|
||||
generation_throughput,
|
||||
stats.num_running_sys,
|
||||
stats.num_swapped_sys,
|
||||
stats.num_waiting_sys,
|
||||
stats.gpu_cache_usage_sys * 100,
|
||||
stats.cpu_cache_usage_sys * 100,
|
||||
str(global_vars.DO_SEQ_LENS)
|
||||
)
|
||||
if (stats.cpu_prefix_cache_hit_rate >= 0
|
||||
or stats.gpu_prefix_cache_hit_rate >= 0):
|
||||
log_fn(
|
||||
"Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%",
|
||||
stats.gpu_prefix_cache_hit_rate * 100,
|
||||
stats.cpu_prefix_cache_hit_rate * 100,
|
||||
)
|
||||
if self.spec_decode_metrics is not None:
|
||||
logger.debug(
|
||||
self._format_spec_decode_metrics_str(
|
||||
self.spec_decode_metrics))
|
||||
|
||||
self._reset(stats, prompt_throughput, generation_throughput)
|
||||
|
||||
|
||||
0
vllm_vacc/vllm/engine/multiprocessing/__init__.py
Normal file
0
vllm_vacc/vllm/engine/multiprocessing/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
103
vllm_vacc/vllm/engine/multiprocessing/engine.py
Normal file
103
vllm_vacc/vllm/engine/multiprocessing/engine.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR,
|
||||
RPCError,
|
||||
RPCProcessRequest,
|
||||
RPCAbortRequest)
|
||||
from vllm.config import VllmConfig
|
||||
import signal
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.config import (
|
||||
maybe_register_config_serialize_by_value)
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MQLLMEngine:
|
||||
|
||||
def _handle_process_request(self, request: RPCProcessRequest):
|
||||
"""Handle RPCProcessRequest by adding it to the LLMEngine."""
|
||||
request_id = request.request_id
|
||||
|
||||
if self._errored_with is not None:
|
||||
rpc_err = RPCError(request_id=request_id,
|
||||
is_engine_errored=True,
|
||||
exception=ENGINE_DEAD_ERROR(self._errored_with))
|
||||
self._send_outputs(rpc_err)
|
||||
|
||||
try:
|
||||
self.engine.add_request(
|
||||
request_id=request_id,
|
||||
prompt=request.prompt,
|
||||
params=request.params,
|
||||
lora_request=request.lora_request,
|
||||
trace_headers=request.trace_headers,
|
||||
prompt_adapter_request=request.prompt_adapter_request,
|
||||
priority=request.priority)
|
||||
|
||||
if self.log_requests:
|
||||
from vllm.engine.multiprocessing.engine import logger
|
||||
|
||||
if request.prompt.get('prompt_token_ids') is not None:
|
||||
# logger.info("Added request: %s, %s, prompt length: %s", request.request_id, request.prompt['prompt_token_ids'], len(request.prompt['prompt_token_ids']))
|
||||
logger.info("Added request: %s, prompt length: %s", request.request_id, len(request.prompt['prompt_token_ids']))
|
||||
else:
|
||||
logger.info("Added request %s.", request.request_id)
|
||||
|
||||
except Exception as e:
|
||||
# We do not set self._errored = True here, since the error
|
||||
# is due to an issue adding this request to the engine,
|
||||
# rather than an issue with the engine itself.
|
||||
is_errored = self._errored_with is not None
|
||||
rpc_err = RPCError(request_id=request_id,
|
||||
is_engine_errored=is_errored,
|
||||
exception=e)
|
||||
self._send_outputs(rpc_err)
|
||||
|
||||
# Remove request from the engine.
|
||||
self.engine.abort_request(request_id)
|
||||
|
||||
def _handle_abort_request(self, request: RPCAbortRequest):
|
||||
self.engine.abort_request(request.request_id)
|
||||
if self.log_requests:
|
||||
from vllm.engine.multiprocessing.engine import logger
|
||||
import vllm_vacc.vllm.model_executor.models.vars as global_vars
|
||||
logger.info("Aborted request: %s, prompt length: %s", request.request_id, global_vars.DO_SEQ_LENS)
|
||||
|
||||
def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext,
|
||||
ipc_path: str, disable_log_stats: bool,
|
||||
disable_log_requests: bool, engine_alive):
|
||||
|
||||
# Patch: only num_speculative_tokens == 1 is supported.
|
||||
speculative_mode = hasattr(vllm_config, 'speculative_config')
|
||||
if speculative_mode and \
|
||||
hasattr(vllm_config.speculative_config, 'num_speculative_tokens') and \
|
||||
vllm_config.speculative_config.num_speculative_tokens != 1:
|
||||
raise ValueError(f'run_mp_engine: only support num_speculative_tokens == 1, but get {vllm_config.speculative_config.num_speculative_tokens}')
|
||||
|
||||
default_model_infos = "default"
|
||||
if speculative_mode:
|
||||
if hasattr(vllm_config.speculative_config, 'method'):
|
||||
default_model_infos = vllm_config.speculative_config.method
|
||||
|
||||
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
|
||||
vllm_vacc_config_manager().update_model_infos(default_model_infos)
|
||||
|
||||
try:
|
||||
# Ensure we can serialize transformer config before spawning
|
||||
maybe_register_config_serialize_by_value()
|
||||
from vllm.engine.multiprocessing.engine import MQLLMEngine,signal_handler
|
||||
engine = MQLLMEngine.from_vllm_config(
|
||||
vllm_config=vllm_config,
|
||||
usage_context=usage_context,
|
||||
disable_log_stats=disable_log_stats,
|
||||
disable_log_requests=disable_log_requests,
|
||||
ipc_path=ipc_path)
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
engine.start()
|
||||
|
||||
except BaseException as e:
|
||||
logger.exception(e)
|
||||
engine_alive.value = False
|
||||
raise e
|
||||
0
vllm_vacc/vllm/entrypoints/__init__.py
Normal file
0
vllm_vacc/vllm/entrypoints/__init__.py
Normal file
BIN
vllm_vacc/vllm/entrypoints/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/entrypoints/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/entrypoints/__pycache__/llm.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/entrypoints/__pycache__/llm.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/entrypoints/__pycache__/renderer.cpython-312.pyc
Normal file
BIN
vllm_vacc/vllm/entrypoints/__pycache__/renderer.cpython-312.pyc
Normal file
Binary file not shown.
102
vllm_vacc/vllm/entrypoints/llm.py
Normal file
102
vllm_vacc/vllm/entrypoints/llm.py
Normal file
@@ -0,0 +1,102 @@
|
||||
|
||||
|
||||
import itertools
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
|
||||
cast, overload)
|
||||
|
||||
import cloudpickle
|
||||
import torch.nn as nn
|
||||
from pydantic import ValidationError
|
||||
from tqdm.auto import tqdm
|
||||
from typing_extensions import TypeVar, deprecated
|
||||
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import (RequestOutputKind, SamplingParams)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_R = TypeVar("_R", default=Any)
|
||||
|
||||
class LLM:
|
||||
|
||||
DEPRECATE_LEGACY: ClassVar[bool] = True
|
||||
def _validate_and_add_requests(
|
||||
self,
|
||||
prompts: Union[PromptType, Sequence[PromptType]],
|
||||
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
|
||||
Sequence[PoolingParams]],
|
||||
*,
|
||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
|
||||
priority: Optional[list[int]] = None,
|
||||
) -> None:
|
||||
|
||||
if isinstance(prompts, (str, dict)):
|
||||
# Convert a single prompt to a list.
|
||||
prompts = [prompts]
|
||||
|
||||
num_requests = len(prompts)
|
||||
if isinstance(params, Sequence) and len(params) != num_requests:
|
||||
raise ValueError("The lengths of prompts and params "
|
||||
"must be the same.")
|
||||
if isinstance(lora_request,
|
||||
Sequence) and len(lora_request) != num_requests:
|
||||
raise ValueError("The lengths of prompts and lora_request "
|
||||
"must be the same.")
|
||||
|
||||
for sp in params if isinstance(params, Sequence) else (params, ):
|
||||
if isinstance(sp, SamplingParams):
|
||||
# We only care about the final output
|
||||
sp.output_kind = RequestOutputKind.FINAL_ONLY
|
||||
|
||||
# Add requests to the engine.
|
||||
it = prompts
|
||||
if use_tqdm:
|
||||
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
|
||||
it = tqdm_func(it, desc="Adding requests")
|
||||
|
||||
if (hasattr(current_platform, 'supports_v1') and current_platform.supports_v1(current_platform)):
|
||||
batch_items = []
|
||||
model_config = self.llm_engine.model_config
|
||||
for i, prompt in enumerate(it):
|
||||
request_id = str(next(self.request_counter))
|
||||
# print("requset_id===========", request_id)
|
||||
param = params[i] if isinstance(params, Sequence) else params
|
||||
tokenization_kwargs: dict[str, Any] = {}
|
||||
_validate_truncation_size(model_config.max_model_len,
|
||||
param.truncate_prompt_tokens,
|
||||
tokenization_kwargs)
|
||||
|
||||
batch_items.append((
|
||||
request_id,
|
||||
prompt,
|
||||
params[i] if isinstance(params, Sequence) else params,
|
||||
None, # arrival_time; pass None when not needed
|
||||
(lora_request[i] if isinstance(lora_request, Sequence)
|
||||
else lora_request),
|
||||
tokenization_kwargs,
|
||||
None, # trace_headers (None when APM/tracing is not used)
|
||||
(priority[i] if priority else 0),
|
||||
))
|
||||
# Submit the whole batch to EngineCore in one call (via ADD_BULK).
|
||||
self.llm_engine.add_requests(batch_items)
|
||||
else:
model_config = self.llm_engine.model_config
for i, prompt in enumerate(it):
param = params[i] if isinstance(params, Sequence) else params
# Build per-request tokenization kwargs so truncation is validated the
# same way as in the V1 path above.
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(model_config.max_model_len,
param.truncate_prompt_tokens,
tokenization_kwargs)
self._add_request(
prompt,
param,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
priority=priority[i] if priority else 0,
)
|
||||
0
vllm_vacc/vllm/entrypoints/openai/__init__.py
Normal file
0
vllm_vacc/vllm/entrypoints/openai/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
345
vllm_vacc/vllm/entrypoints/openai/serving_completion.py
Normal file
345
vllm_vacc/vllm/entrypoints/openai/serving_completion.py
Normal file
@@ -0,0 +1,345 @@
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import jinja2
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice,
|
||||
CompletionStreamResponse,
|
||||
ErrorResponse,
|
||||
RequestResponseMetadata,
|
||||
UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (
|
||||
EmbedsPrompt as ServingEngineEmbedsPrompt)
|
||||
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
|
||||
TextTokensPrompt,
|
||||
clamp_prompt_logprobs,
|
||||
is_text_tokens_prompt)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.utils import get_max_tokens
|
||||
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
|
||||
is_tokens_prompt)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import merge_async_iterators
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
|
||||
from vllm.entrypoints.openai.serving_completion import logger
|
||||
from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of, make_async,
merge_async_iterators, random_uuid)
|
||||
from vllm_vacc.vllm.model_executor.models.vars import LLM_MAX_PREFILL_SEQ_LEN
|
||||
|
||||
|
||||
class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
enable_prompt_tokens_details: bool = False,
|
||||
enable_force_include_usage: bool = False,
|
||||
enable_strict_batch_barrier: bool = True,
|
||||
log_error_stack: bool = False,
|
||||
):
|
||||
|
||||
self.engine_client = engine_client
|
||||
self.model_config = model_config
|
||||
self.max_model_len = model_config.max_model_len
|
||||
|
||||
self.models = models
|
||||
|
||||
self.request_logger = request_logger
|
||||
self.return_tokens_as_token_ids = return_tokens_as_token_ids
|
||||
self.enable_force_include_usage = enable_force_include_usage
|
||||
|
||||
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
self._async_tokenizer_pool: dict[AnyTokenizer,
|
||||
AsyncMicrobatchTokenizer] = {}
|
||||
self.log_error_stack = log_error_stack
|
||||
|
||||
self.enable_prompt_tokens_details = enable_prompt_tokens_details
|
||||
self.default_sampling_params = (
|
||||
self.model_config.get_diff_sampling_param())
|
||||
if self.default_sampling_params:
|
||||
source = self.model_config.generation_config
|
||||
source = "model" if source == "auto" else source
|
||||
logger.info("Using default completion sampling params from %s: %s",
|
||||
source, self.default_sampling_params)
|
||||
self.enable_strict_batch_barrier = enable_strict_batch_barrier
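# When enabled and a single request expands to multiple prompts, barrier
# metadata is attached to each item's SamplingParams (see create_completion
# below) so the batch is held until every item is ready and then scheduled
# together.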
|
||||
|
||||
|
||||
async def create_completion(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/completions/create
|
||||
for the API specification. This API mimics the OpenAI Completion API.
|
||||
|
||||
NOTE: Currently we do not support the following feature:
|
||||
- suffix (the language models we currently support do not support
|
||||
suffix)
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
# Return error for unsupported features.
|
||||
if request.suffix is not None:
|
||||
return self.create_error_response(
|
||||
"suffix is not currently supported")
|
||||
|
||||
if request.echo and request.prompt_embeds is not None:
|
||||
return self.create_error_response(
|
||||
"Echo is unsupported with prompt embeds.")
|
||||
|
||||
if (request.prompt_logprobs is not None
|
||||
and request.prompt_embeds is not None):
|
||||
return self.create_error_response(
|
||||
"prompt_logprobs is not compatible with prompt embeds.")
|
||||
|
||||
request_id = (
|
||||
f"cmpl-"
|
||||
f"{self._base_request_id(raw_request, request.request_id)}")
|
||||
created_time = int(time.time())
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self._get_renderer(tokenizer)
|
||||
|
||||
engine_prompts = await renderer.render_prompt_and_embeds(
|
||||
prompt_or_prompts=request.prompt,
|
||||
prompt_embeds=request.prompt_embeds,
|
||||
deepstack_input_embeds=request.deepstack_input_embeds if hasattr(request, 'deepstack_input_embeds') else None,
|
||||
config=self._build_render_config(request),
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
except TypeError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
except RuntimeError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
except jinja2.TemplateError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
total_num_prompts = len(engine_prompts)
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
sampling_params: Union[SamplingParams, BeamSearchParams]
|
||||
# Mypy does not infer that engine_prompt will have only one of
|
||||
# "prompt_token_ids" or "prompt_embeds" defined, and both of
|
||||
# these as Union[object, the expected type], where it infers
|
||||
# object if engine_prompt is a subclass of one of the
|
||||
# typeddicts that defines both keys. Worse, because of
|
||||
# https://github.com/python/mypy/issues/8586, mypy does not
|
||||
# infer the type of engine_prompt correctly because of the
|
||||
# enumerate. So we need an unnecessary cast here.
|
||||
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
|
||||
engine_prompt)
|
||||
if is_embeds_prompt(engine_prompt):
|
||||
input_length = len(engine_prompt["prompt_embeds"])
|
||||
elif is_tokens_prompt(engine_prompt):
|
||||
input_length = len(engine_prompt["prompt_token_ids"])
|
||||
if input_length > LLM_MAX_PREFILL_SEQ_LEN:
|
||||
raise ValueError(
|
||||
f"This model's maximum input seq length limit is "
|
||||
f"{LLM_MAX_PREFILL_SEQ_LEN} tokens. However, you requested "
|
||||
f"({input_length} in the input messages, "
|
||||
f"Please reduce the length of the input messages.")
|
||||
else:
|
||||
assert_never(engine_prompt)
|
||||
|
||||
if self.default_sampling_params is None:
|
||||
self.default_sampling_params = {}
|
||||
|
||||
max_tokens = get_max_tokens(
|
||||
max_model_len=self.max_model_len,
|
||||
request=request,
|
||||
input_length=input_length,
|
||||
default_sampling_params=self.default_sampling_params,
|
||||
)
|
||||
|
||||
if request.use_beam_search:
|
||||
sampling_params = request.to_beam_search_params(
|
||||
max_tokens, self.default_sampling_params)
|
||||
else:
|
||||
sampling_params = request.to_sampling_params(
|
||||
max_tokens,
|
||||
self.model_config.logits_processor_pattern,
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
# Inject strict batch barrier metadata so this batch is held
|
||||
# until all items are ready, then scheduled together.
|
||||
if (self.enable_strict_batch_barrier
|
||||
and total_num_prompts > 1
|
||||
and isinstance(sampling_params, SamplingParams)):
|
||||
if sampling_params.extra_args is None:
|
||||
sampling_params.extra_args = {}
|
||||
sampling_params.extra_args.setdefault("barrier_group_id",
|
||||
request_id)
|
||||
sampling_params.extra_args.setdefault("barrier_group_size",
|
||||
total_num_prompts)
|
||||
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
# Mypy inconsistently requires this second cast in different
|
||||
# environments. It shouldn't be necessary (redundant from above)
|
||||
# but pre-commit in CI fails without it.
|
||||
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
|
||||
engine_prompt)
|
||||
if isinstance(sampling_params, BeamSearchParams):
|
||||
generator = self.engine_client.beam_search(
|
||||
prompt=engine_prompt,
|
||||
request_id=request_id,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
else:
|
||||
generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
logger.error(e)
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
model_name = self.models.model_name(lora_request)
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
||||
# results. Noting that best_of is only supported in V0. In addition,
|
||||
# we do not stream the results when use beam search.
|
||||
stream = (request.stream
|
||||
and (request.best_of is None or request.n == request.best_of)
|
||||
and not request.use_beam_search)
|
||||
|
||||
# Streaming response
|
||||
if stream:
|
||||
return self.completion_stream_generator(
|
||||
request,
|
||||
engine_prompts,
|
||||
result_generator,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
num_prompts=num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
request_metadata=request_metadata,
|
||||
enable_force_include_usage=self.enable_force_include_usage,
|
||||
)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
for i, final_res in enumerate(final_res_batch):
|
||||
assert final_res is not None
|
||||
|
||||
# The output should contain the input text
|
||||
# We did not pass it into vLLM engine to avoid being redundant
|
||||
# with the inputs token IDs
|
||||
if final_res.prompt is None:
|
||||
engine_prompt = engine_prompts[i]
|
||||
final_res.prompt = None if is_embeds_prompt(
|
||||
engine_prompt) else engine_prompt.get("prompt")
|
||||
|
||||
final_res_batch_checked = cast(list[RequestOutput],
|
||||
final_res_batch)
|
||||
|
||||
response = self.request_output_to_completion_response(
|
||||
final_res_batch_checked,
|
||||
request,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
tokenizer,
|
||||
request_metadata,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
if request.stream:
|
||||
response_json = response.model_dump_json()
|
||||
|
||||
async def fake_stream_generator() -> AsyncGenerator[str, None]:
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return fake_stream_generator()
|
||||
|
||||
return response
|
||||
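Note: the barrier_group_id / barrier_group_size values injected into extra_args above are plain metadata; a scheduler-side consumer (not shown in this commit) has to hold requests until the whole group has arrived. A minimal sketch of such a consumer follows, assuming illustrative names (waiting, req.sampling_params) rather than actual vLLM scheduler APIs:

# Hypothetical sketch: how the barrier metadata injected above could be
# consumed by a scheduler-side hook. `waiting` and the request attributes
# are illustrative assumptions, not vLLM APIs.
from collections import defaultdict

def group_by_barrier(waiting):
    """Hold requests until their whole barrier group has arrived."""
    groups = defaultdict(list)
    ready = []
    for req in waiting:
        extra = req.sampling_params.extra_args or {}
        group_id = extra.get("barrier_group_id")
        if group_id is None:
            ready.append(req)  # not part of a barrier group
            continue
        groups[group_id].append(req)
        if len(groups[group_id]) == extra["barrier_group_size"]:
            ready.extend(groups.pop(group_id))  # release the full batch
    return ready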
191
vllm_vacc/vllm/entrypoints/openai/serving_engine.py
Normal file
@@ -0,0 +1,191 @@
# SPDX-License-Identifier: Apache-2.0

import json
from concurrent.futures.thread import ThreadPoolExecutor
from http import HTTPStatus
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
                    Optional, Sequence, Tuple, TypedDict, Union)

from fastapi import Request
from pydantic import Field
from starlette.datastructures import Headers
from typing_extensions import Annotated
import torch

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                         ChatTemplateContentFormatOption,
                                         ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
                                         parse_chat_messages_futures,
                                         resolve_chat_template_content_format)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest,
                                              DetokenizeRequest,
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
                                              ErrorResponse, RerankRequest,
                                              ScoreRequest,
                                              TokenizeChatRequest,
                                              TokenizeCompletionRequest)
# yapf: enable
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.entrypoints.openai.serving_engine import AnyRequest, TextTokensPrompt
# from vllm.model_executor.sampling_metadata import _SAMPLING_EPS
from vllm.v1.sample.sampler import _SAMPLING_EPS
import os

logger = init_logger(__name__)

from vllm_vacc.vllm.model_executor.models.vars import LLM_MAX_PREFILL_SEQ_LEN
from vllm_vacc.vllm.model_executor.models.vars import CUT_PREFILL_SEQ_LEN


class EmbedsPrompt(TypedDict):
    prompt_embeds: torch.Tensor
    deepstack_input_embeds: Optional[dict]


class OpenAIServing:

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: List[int],
        input_text: str,
    ) -> TextTokensPrompt:
        # Parameters set by the client; if they are not set, they are read
        # again from generation_config.json.
        if CUT_PREFILL_SEQ_LEN > 0 and CUT_PREFILL_SEQ_LEN < len(input_ids):
            cut_before = CUT_PREFILL_SEQ_LEN // 2
            cut_after = CUT_PREFILL_SEQ_LEN - cut_before
            input_ids = input_ids[:cut_before] + input_ids[(-1) * cut_after:]
        token_num = len(input_ids)

        if not self.model_config.pooler_config:
            if (request.repetition_penalty is not None
                    and abs(request.repetition_penalty - 1.0) >= _SAMPLING_EPS):
                raise ValueError(
                    f"Unsupported penalty for sampler: "
                    f"request.repetition_penalty: {request.repetition_penalty}. "
                    f"Please remove the penalty parameter in the client and "
                    f"try again.")
            if request.min_p is not None and request.min_p > _SAMPLING_EPS:
                raise ValueError(
                    f"Unsupported min_p {request.min_p} for sampler")
            if request.prompt_logprobs is not None:
                raise ValueError(
                    f"Unsupported prompt_logprobs {request.prompt_logprobs} "
                    f"for sampler")

            # model_type = self.model_config.hf_config.model_type
            # if model_type == "deepseek_v3":
            if token_num > LLM_MAX_PREFILL_SEQ_LEN:
                raise ValueError(
                    f"This model's maximum input seq length limit is "
                    f"{LLM_MAX_PREFILL_SEQ_LEN} tokens. However, you requested "
                    f"{token_num} tokens in the input messages. "
                    f"Please reduce the length of the input messages.")

        # Note: EmbeddingRequest and ScoreRequest don't have max_tokens
        if isinstance(request,
                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
                       ScoreRequest, RerankRequest)):

            operation = "score" if isinstance(request, ScoreRequest) \
                else "embedding generation"
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # The chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = request.max_tokens
        if max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages. "
                    f"Please reduce the length of the messages.")
        elif token_num + max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    def _log_inputs(
        self,
        request_id: str,
        inputs,
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
    ) -> None:
        # The request_logger check is moved to just before it is used.
        # if self.request_logger is None:
        #     return
        # If self.model_config.pooler_config is not None, the task is
        # embedding, not a generation task.
        if self.model_config.pooler_config:
            return
        prompt, prompt_token_ids, prompt_embeds = None, None, None
        if isinstance(inputs, str):
            prompt = inputs
        elif isinstance(inputs, list):
            prompt_token_ids = inputs
        else:
            prompt = getattr(inputs, 'prompt', None)
            prompt_token_ids = getattr(inputs, 'prompt_token_ids', None)

        # Penalty values read from generation_config: if present, warn and
        # reset them.
        if (params.repetition_penalty is not None
                and abs(params.repetition_penalty - 1.0) >= _SAMPLING_EPS):
            logger.warning(
                "\033[93mWARNING \033[0m"
                ": Unsupported penalty for sampler. "
                f"params.repetition_penalty: {params.repetition_penalty}; "
                "please set extra_body = {'repetition_penalty': 1.0}.\n"
                "Now set: repetition_penalty: 1.0"
            )
            # params.presence_penalty = 0
            # params.frequency_penalty = 0
            params.repetition_penalty = 1

        if (hasattr(params, "min_p") and params.min_p is not None
                and params.min_p > _SAMPLING_EPS):
            logger.warning(
                f"\033[93mWARNING \033[0m : unsupported min_p {params.min_p} "
                f"for sampler")
            params.min_p = 0
        if self.request_logger is None:
            return
        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            prompt_embeds,
            params=params,
            lora_request=lora_request,
        )
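For clarity, the CUT_PREFILL_SEQ_LEN branch in _validate_input keeps the head and the tail of an over-long prompt and drops the middle. A small worked example with made-up values:

# Illustration of the head-and-tail truncation performed by _validate_input
# when CUT_PREFILL_SEQ_LEN is set (values chosen only for the example).
CUT_PREFILL_SEQ_LEN = 6
input_ids = list(range(10))                       # [0, 1, ..., 9]
cut_before = CUT_PREFILL_SEQ_LEN // 2             # 3
cut_after = CUT_PREFILL_SEQ_LEN - cut_before      # 3
truncated = input_ids[:cut_before] + input_ids[-cut_after:]
assert truncated == [0, 1, 2, 7, 8, 9]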
127
vllm_vacc/vllm/entrypoints/renderer.py
Normal file
@@ -0,0 +1,127 @@
import asyncio
import io
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Annotated, Optional, Union

import pybase64
import torch
from pydantic import Field

from vllm.config import ModelConfig
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AsyncMicrobatchTokenizer


class BaseRenderer(ABC):
    """
    Base class for unified input processing and rendering.

    The Renderer serves as a unified input processor that consolidates
    tokenization, chat template formatting, and multimodal input handling
    into a single component.
    It converts high-level API requests (OpenAI-style JSON) into token IDs and
    multimodal features ready for engine consumption.

    Key responsibilities:
    - Convert text prompts to token sequences with proper special tokens
    - Apply chat templates and format conversations
    - Handle multimodal inputs (images, audio, etc.) when applicable
    - Manage prompt truncation and length validation
    - Provide clean separation between API layer and engine core
    """

    @classmethod
    def load_prompt_embeds(
        cls,
        prompt_embeds: Union[bytes, list[bytes]],
        deepstack_input_embeds: Optional[dict[str, Union[bytes, str]]] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=0)]] = None,
        cache_salt: Optional[str] = None,
    ) -> list[EngineEmbedsPrompt]:
        """Load and validate base64-encoded embeddings into prompt objects."""

        def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt:
            tensor = torch.load(
                io.BytesIO(pybase64.b64decode(embed, validate=True)),
                weights_only=True,
                map_location=torch.device("cpu"),
            )
            assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
                torch.float32,
                torch.bfloat16,
                torch.float16,
            )
            tensor = tensor.to_dense()
            if tensor.dim() > 2:
                tensor = tensor.squeeze(0)
            assert tensor.dim() == 2
            if truncate_prompt_tokens is not None:
                tensor = tensor[-truncate_prompt_tokens:]
            embeds_prompt = EngineEmbedsPrompt(prompt_embeds=tensor)
            if cache_salt is not None:
                embeds_prompt["cache_salt"] = cache_salt

            if deepstack_input_embeds is not None:
                all_tensor = []
                from vllm.sequence import IntermediateTensors
                tensor_dict = torch.load(
                    io.BytesIO(
                        pybase64.b64decode(deepstack_input_embeds,
                                           validate=True)))
                for k in tensor_dict:
                    all_tensor.append(tensor_dict[k].unsqueeze(0))

                all_tensor = torch.concatenate(all_tensor, 0)
                # IntermediateTensors(tensors=tensor_dict)
                embeds_prompt["deepstack_input_embeds"] = all_tensor

            return embeds_prompt

        if isinstance(prompt_embeds, list):
            return [_load_and_validate_embed(embed) for embed in prompt_embeds]

        return [_load_and_validate_embed(prompt_embeds)]


class CompletionRenderer(BaseRenderer):

    async def render_prompt_and_embeds(
        self,
        *,
        prompt_or_prompts: Optional[Union[str, list[str], list[int],
                                          list[list[int]]]] = None,
        prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
        deepstack_input_embeds: Optional[Union[bytes, list[bytes]]] = None,
        config: "RenderConfig",
    ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
        """
        Render text/token prompts and/or precomputed embedding prompts. At
        least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
        """
        truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
            config.truncate_prompt_tokens, config.max_length)
        if truncate_prompt_tokens == 0:
            return []

        rendered: list[Union[EngineTokensPrompt, EngineEmbedsPrompt]] = []

        if prompt_embeds is not None:
            rendered.extend(
                self.load_prompt_embeds(prompt_embeds, deepstack_input_embeds,
                                        truncate_prompt_tokens,
                                        config.cache_salt))
        if prompt_or_prompts is None or prompt_or_prompts == "":
            return rendered

        token_prompts = await self.render_prompt(
            prompt_or_prompts=prompt_or_prompts,
            config=config,
        )
        rendered.extend(token_prompts)

        return rendered
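As a reference for callers, the payload that load_prompt_embeds decodes is simply a torch.save of a (seq_len, hidden_size) tensor, base64-encoded. A minimal client-side sketch; the hidden size is an arbitrary assumption for illustration:

# Build a (seq_len, hidden_size) embedding tensor; 8 x 4096 is only an example.
import base64
import io
import torch

embeds = torch.randn(8, 4096, dtype=torch.float16)

buf = io.BytesIO()
torch.save(embeds, buf)  # serialize the tensor
prompt_embeds_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

# This string can be sent as the request's prompt_embeds field; the server
# recovers the tensor with pybase64.b64decode + torch.load in
# load_prompt_embeds.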
0
vllm_vacc/vllm/executor/__init__.py
Normal file
BIN
vllm_vacc/vllm/executor/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
20
vllm_vacc/vllm/executor/executor_base.py
Normal file
@@ -0,0 +1,20 @@
import asyncio
from typing import List

from vllm.v1.outputs import PoolerOutput, SamplerOutput
from vllm.sequence import ExecuteModelRequest

# class DistributedExecutorBase():
#     """Abstract superclass of distributed executor implementations."""


async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
    if self.parallel_worker_tasks is None:
        # Start the model execution loop running in the parallel workers
        self.parallel_worker_tasks = asyncio.create_task(
            self._start_worker_execution_loop())
        await asyncio.sleep(0)
    # Only the driver worker returns the sampling results.
    await asyncio.sleep(0)
    return await self._driver_execute_model_async(execute_model_req)
0
vllm_vacc/vllm/inputs/__init__.py
Normal file
BIN
vllm_vacc/vllm/inputs/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/inputs/__pycache__/data.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/inputs/__pycache__/preprocess.cpython-312.pyc
Normal file
Binary file not shown.
55
vllm_vacc/vllm/inputs/data.py
Normal file
@@ -0,0 +1,55 @@
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast

import torch
from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar

from vllm.sequence import IntermediateTensors

if TYPE_CHECKING:
    from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalInputs,
                                        MultiModalUUIDDict)


class EmbedsPrompt(TypedDict):
    """Schema for a prompt provided via token embeddings."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

    deepstack_input_embeds: Optional[IntermediateTensors]

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """


class EmbedsInputs(TypedDict):
    """Represents embeddings-based inputs."""

    type: Literal["embeds"]
    """The type of inputs."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

    deepstack_input_embeds: torch.Tensor

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """


def embeds_inputs(
    prompt_embeds: torch.Tensor,
    deepstack_input_embeds: torch.Tensor,
    cache_salt: Optional[str] = None,
) -> EmbedsInputs:
    """Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional
    values."""
    inputs = EmbedsInputs(type="embeds",
                          prompt_embeds=prompt_embeds,
                          deepstack_input_embeds=deepstack_input_embeds)

    if cache_salt is not None:
        inputs["cache_salt"] = cache_salt

    return inputs
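A short usage sketch for embeds_inputs; the tensor shapes are assumptions chosen for illustration (a (seq_len, hidden_size) prompt tensor and a stacked deepstack tensor):

# Wrap precomputed embeddings into the EmbedsInputs structure consumed by
# the engine. The import path follows this commit's package layout.
import torch
from vllm_vacc.vllm.inputs.data import embeds_inputs

prompt_embeds = torch.zeros(8, 4096)        # (seq_len, hidden_size)
deepstack_embeds = torch.zeros(3, 8, 4096)  # stacked per-layer embeddings

inputs = embeds_inputs(
    prompt_embeds=prompt_embeds,
    deepstack_input_embeds=deepstack_embeds,
    cache_salt="my-session",                # optional prefix-cache salt
)
assert inputs["type"] == "embeds"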
54
vllm_vacc/vllm/inputs/preprocess.py
Normal file
@@ -0,0 +1,54 @@
from collections.abc import Mapping
from typing import Any, Optional, Union, cast

from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                                    MultiModalInputs, MultiModalUUIDDict)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer

from .data import EmbedsInputs, EmbedsPrompt, embeds_inputs

logger = init_logger(__name__)


class InputPreprocessor:

    def _process_embeds(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
        if not self.model_config.enable_prompt_embeds:
            raise ValueError("You must set `--enable-prompt-embeds` to input "
                             "`prompt_embeds`.")

        prompt_embeds = parsed_content["prompt_embeds"]
        deepstack_input_embeds = None
        if 'deepstack_input_embeds' in parsed_content:
            deepstack_input_embeds = parsed_content["deepstack_input_embeds"]

        # prompt_embeds must be (seq_len, hidden_size), but if the user
        # passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
        # we can unambiguously process the intent by squeezing the batch
        # dimension.
        if prompt_embeds.ndim == 3:
            prompt_embeds = prompt_embeds.squeeze(dim=0)

        if prompt_embeds.ndim != 2:
            raise ValueError(
                "prompt_embeds must be of shape (seq_len, hidden_size).")

        # Tensors must be on CPU for serialization between processes
        # in the MsgpackEncoder. Casting to CPU here ensures that there is no
        # hidden device transfer in the critical path of generation.
        prompt_embeds = prompt_embeds.cpu()

        return embeds_inputs(prompt_embeds=prompt_embeds,
                             deepstack_input_embeds=deepstack_input_embeds,
                             cache_salt=parsed_content.get("cache_salt"))
0
vllm_vacc/vllm/model_executor/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
38
vllm_vacc/vllm/model_executor/custom_op.py
Normal file
@@ -0,0 +1,38 @@
import torch.nn as nn

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)


class CustomOp(nn.Module):

    def forward_vacc(self, *args, **kwargs):
        raise NotImplementedError

    def dispatch_forward(self):
        # NOTE(woosuk): Here we assume that vLLM was built for only one
        # specific backend. Currently, we do not support dynamic dispatching.

        enabled = self.enabled()
        logger.debug("custom op %s %s", self.__class__.__name__,
                     "enabled" if enabled else "disabled")

        if not enabled:
            return self.forward_native

        return self.forward

        # The per-platform dispatch below is kept for reference but is
        # unreachable: the VACC patch always returns self.forward above.
        if current_platform.is_rocm():
            return self.forward_hip
        elif current_platform.is_cpu():
            return self.forward_cpu
        elif current_platform.is_hpu():
            return self.forward_hpu
        elif current_platform.is_tpu():
            return self.forward_tpu
        elif current_platform.is_xpu():
            return self.forward_xpu
        elif current_platform.is_vacc():
            return self.forward
        else:
            return self.forward_cuda
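Since dispatch_forward always resolves to self.forward on this platform, a CustomOp subclass would route forward() to its VACC implementation itself. SiluAndMulVacc below is a hypothetical sketch for illustration, not an op shipped in this commit:

# Hypothetical CustomOp subclass under the dispatch rule above. The class
# name and the fallback body are assumptions; a real kernel would call into
# the VACC runtime instead.
import torch
from vllm_vacc.vllm.model_executor.custom_op import CustomOp

class SiluAndMulVacc(CustomOp):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return torch.nn.functional.silu(x[..., :d]) * x[..., d:]

    def forward_vacc(self, x: torch.Tensor) -> torch.Tensor:
        # Placeholder: fall back to the reference implementation.
        return self.forward_native(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # dispatch_forward() returns this method, so route to the VACC path.
        return self.forward_vacc(x)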
0
vllm_vacc/vllm/model_executor/layers/__init__.py
Normal file
Some files were not shown because too many files have changed in this diff.