There is a lot of hack code for v0.11.0, which makes the codebase hard
to upgrade to newer vLLM versions. Since v0.11.2 will be released soon,
let's drop v0.11.0 support first; then we'll upgrade to v0.11.2.
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
import array
import hashlib
import json
import os
import re
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple, Union

import torch
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
    KVConnectorMetadata
from vllm.utils import logger
from vllm.utils.math_utils import cdiv
from vllm.v1.core.sched.output import NewRequestData

DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200  # 3.125 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824  # 1.0 GiB


@dataclass
class MooncakeEngineMetadata:
    model_name: str
    """name of the LLM model"""
    world_size: int
    """world size when running under a distributed setting"""
    worker_id: int
    """worker id when running under a distributed setting"""
    kv_dtype: torch.dtype
    """the dtype of kv tensors"""
    kv_shape: tuple[int, int, int, int, int]
    """the shape of kv tensors:
    (num_layer, 2, metadata.block_size, num_kv_head, head_size)"""
    block_size: int = 128
    use_mla: bool = False
    """whether to use MLA"""


@dataclass(order=True)
class MooncakeEngineKey:
    model_name: str
    world_size: int
    worker_id: int
    chunk_hash: str

    def __hash__(self):
        return hash((
            self.model_name,
            self.world_size,
            self.worker_id,
            self.chunk_hash,
        ))

    def to_string(self):
        return (f"{self.model_name}@{self.world_size}"
                f"@{self.worker_id}@{self.chunk_hash}")

    def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]:
        """Split the key into multiple keys, one for each layer"""
        keys = []
        for layer_id in range(num_layers):
            keys.append(
                LayerMooncakeEngineKey(
                    self.model_name,
                    self.world_size,
                    self.worker_id,
                    self.chunk_hash,
                    layer_id,
                ))
        return keys

    def to_dict(self):
        # Note(Kuntai): this is used for serializing CacheEngineKey via
        # msgpack.
        return {
            "__type__": "CacheEngineKey",
            "model_name": self.model_name,
            "world_size": self.world_size,
            "worker_id": self.worker_id,
            "chunk_hash": self.chunk_hash,
        }

    @staticmethod
    def from_dict(d):
        return MooncakeEngineKey(
            model_name=d["model_name"],
            world_size=d["world_size"],
            worker_id=d["worker_id"],
            chunk_hash=d["chunk_hash"],
        )
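

# Example (hypothetical values): keys round-trip through
# to_dict()/from_dict(), which the msgpack serialization path relies on,
# and to_string() joins the fields with '@':
#
#   key = MooncakeEngineKey("llama", 8, 0, "abc123")
#   assert MooncakeEngineKey.from_dict(key.to_dict()) == key
#   assert key.to_string() == "llama@8@0@abc123"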


@dataclass(order=True)
class LayerMooncakeEngineKey(MooncakeEngineKey):
    """A key for the layer cache engine"""

    layer_id: int

    def __hash__(self):
        return hash((
            self.model_name,
            self.world_size,
            self.worker_id,
            self.chunk_hash,
            self.layer_id,
        ))

    def to_string(self):
        return (f"{self.model_name}@{self.world_size}"
                f"@{self.worker_id}@{self.chunk_hash}@{self.layer_id}")


class ChunkedTokenDatabase:

    def __init__(
        self,
        metadata: MooncakeEngineMetadata,
    ):
        self.metadata = metadata

    def _make_key_by_hash(self,
                          chunk_hash: str,
                          layer_id: Optional[int] = None):
        assert self.metadata is not None
        return MooncakeEngineKey(
            self.metadata.model_name,
            self.metadata.world_size,
            self.metadata.worker_id,
            chunk_hash,
        )

    def _hash(
        self,
        tokens: Union[torch.Tensor, List[int]],
        prefix_hash: str,
    ) -> str:
        # TODO: change it to a more efficient hash function
        if isinstance(tokens, torch.Tensor):
            tokens_bytes = tokens.cpu().to(torch.uint32).numpy().tobytes()
        elif isinstance(tokens, list):
            tokens_bytes = array.array("I", tokens).tobytes()
        else:
            raise TypeError(f"Unsupported token container: {type(tokens)}")
        return hashlib.sha256(prefix_hash.encode("ascii") +
                              tokens_bytes).hexdigest()

    def _chunk_tokens(
        self,
        tokens: Union[torch.Tensor, List[int]],
    ) -> Iterable[Union[torch.Tensor, List[int]]]:
        """
        Chunk the tokens into chunks of size self.metadata.block_size.

        :param tokens: the input tokens, with shape [seq_len]

        :return: a generator of chunks of tokens, each with
            shape [metadata.block_size]
        """
        for i in range(0, len(tokens), self.metadata.block_size):
            yield tokens[i:i + self.metadata.block_size]

    def _prefix_hash(
        self,
        token_chunks: Iterable[Union[torch.Tensor, List[int]]],
    ) -> Iterable[str]:
        prefix_hash = ''
        for token_chunk in token_chunks:
            prefix_hash = self._hash(token_chunk, prefix_hash)
            yield prefix_hash

    def process_tokens(
        self,
        tokens: Union[torch.Tensor, List[int]],
        mask: Optional[torch.Tensor] = None,
    ) -> Iterable[Tuple[int, int, MooncakeEngineKey]]:
        """Process the tokens and return the corresponding cache engine keys.

        :param Union[torch.Tensor, List[int]] tokens: The tokens to process.

        :param Optional[torch.Tensor] mask: The mask for the tokens. Should
            have the same length as tokens, and should ALWAYS look like
            FFFFFTTTTTTT: the Falses are ALWAYS at the PREFIX of the tensor,
            and True means the token needs to be matched.

        :returns: An iterable of tuples with three elements. The first element
            is the start index of the tokens for the key. The second element
            is the end index of the tokens for the key. The third element is
            the cache engine key for the tokens.

        :raises: ValueError if the number of Falses in the mask is not a
            multiple of the chunk size.
        """
        if mask is not None:
            num_falses = mask.numel() - mask.long().sum().item()
        else:
            num_falses = 0

        if num_falses % self.metadata.block_size != 0:
            raise ValueError(
                "The number of Falses in the mask is not a multiple of the "
                "chunk size.")
        total_len = len(tokens)

        token_chunks = self._chunk_tokens(tokens)
        prefix_hashes = self._prefix_hash(token_chunks)

        start_idx = 0
        for chunk_id, hash_val in enumerate(prefix_hashes):
            start_idx = chunk_id * self.metadata.block_size
            end_idx = min(start_idx + self.metadata.block_size, total_len)
            if start_idx < num_falses:
                continue
            else:
                yield start_idx, end_idx, self._make_key_by_hash(hash_val)
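

# Usage sketch (hypothetical metadata with block_size=128): given 300
# tokens and a mask whose first 128 entries are False (that prefix is
# already cached in vLLM), process_tokens skips the first chunk and
# yields one key per remaining chunk:
#
#   db = ChunkedTokenDatabase(metadata)
#   for start, end, key in db.process_tokens(list(range(300)), mask):
#       ...  # yields (128, 256, key_1), then (256, 300, key_2)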


@dataclass
class LoadSpec:
    # Number of tokens cached in vLLM
    vllm_cached_tokens: int
    # Number of tokens that are cached in mooncake
    mooncake_cached_tokens: int
    # Whether the scheduler allows us to load the tokens
    can_load: bool


@dataclass
class SaveSpec:
    # Skip already-saved tokens
    skip_leading_tokens: int
    # Whether the scheduler allows us to save the tokens
    can_save: bool


@dataclass
class RequestTracker:
    # Request id
    req_id: str

    # The token ids that have been scheduled so far
    token_ids: list[int]

    # The block ids that have been allocated so far
    # NOTE: allocated blocks could be more than the number of tokens
    # FIXME: need to check whether the block ids will be changed after
    # preemption
    allocated_block_ids: list[int]

    # The number of tokens that have been saved
    num_saved_tokens: int = 0

    @staticmethod
    def from_new_request(
        new_request: "NewRequestData",
        num_tokens_to_compute: int,
    ) -> "RequestTracker":
        """Create the request tracker from a new request.

        Args:
            new_request (NewRequestData): the new request data.
            num_tokens_to_compute (int): the number of tokens that will
                be 'computed', including the `num_computed_tokens` (vLLM's
                local cache hit) and new tokens that will be scheduled.
        """
        # vLLM 0.9.0 update: request.block_ids changed from list[int] to
        # list[list[int]], so check the type of request.block_ids.
        if not isinstance(new_request.block_ids[0], list):
            unfolded_block_ids = new_request.block_ids.copy()
        else:
            unfolded_block_ids = new_request.block_ids[0].copy()

        return RequestTracker(
            req_id=new_request.req_id,
            token_ids=new_request.prompt_token_ids[:num_tokens_to_compute].
            copy(),
            allocated_block_ids=unfolded_block_ids,
            num_saved_tokens=0,
        )

    def update(
        self,
        new_token_ids: list[int],
        new_block_ids: Union[tuple[list[int], ...], list[int]],
    ) -> None:
        """Update the request tracker when a running request is
        scheduled again
        """
        self.token_ids.extend(new_token_ids)

        if len(new_block_ids) == 0:
            new_block_ids = []
        elif isinstance(new_block_ids, tuple):
            new_block_ids = new_block_ids[0]
        elif isinstance(new_block_ids, list):
            pass
        else:
            raise ValueError(
                f"Unsupported new_block_ids type {type(new_block_ids)}")
        self.allocated_block_ids.extend(new_block_ids)
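

# Example (hypothetical values): update() accepts the new block ids either
# as a plain list or wrapped in a tuple (one entry per KV cache group):
#
#   tracker.update(new_token_ids=[17, 23], new_block_ids=([42, 43],))
#   # token_ids is extended with [17, 23],
#   # allocated_block_ids with [42, 43]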


@dataclass
class ReqMeta:
    # Request id
    req_id: str
    # Request tokens
    token_ids: torch.Tensor

    block_ids: list[int]
    # # Slot mapping if exchanged for block_ids
    # slot_mapping: torch.Tensor
    # Skip save or not
    save_spec: Optional[SaveSpec] = None
    # load_spec
    load_spec: Optional[LoadSpec] = None

    is_last_chunk: Optional[bool] = None

    @staticmethod
    def from_request_tracker(
        tracker: RequestTracker,
        block_size: int,
        load_spec: Optional[LoadSpec] = None,
        skip_save: Optional[bool] = False,
        is_last_chunk: Optional[bool] = None,
        discard_partial_chunks: bool = True,
    ) -> Optional["ReqMeta"]:
        """Create the request metadata from a request tracker.

        Args:
            tracker (RequestTracker): the request tracker.
            block_size (int): the block size in vLLM.
            load_spec (Optional[LoadSpec]): the load spec for KV cache
                loading.
            skip_save (bool): whether to skip the save operation.
            is_last_chunk (Optional[bool]): whether this is the last chunk
                of the request.
            discard_partial_chunks (bool): whether to discard partial chunks.

        Returns:
            the request metadata if we need to perform load/save
            operations, None otherwise.
        """
        input_token_ids = tracker.token_ids
        input_token_len = len(input_token_ids)

        # For the save operation: do not save if both conditions are met:
        # 1. the tokens have already been saved before
        #    (num_saved_tokens > 0)
        # 2. the number of unsaved tokens has not reached the chunk
        #    boundary
        skip_leading_tokens = tracker.num_saved_tokens
        chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) *
                          block_size if discard_partial_chunks else 0)
        # Calculate the number of tokens to save based on the
        # discard_partial_chunks setting
        num_tokens_to_save = ((input_token_len // block_size * block_size)
                              if discard_partial_chunks else input_token_len)

        skip_save = skip_save or num_tokens_to_save < chunk_boundary
        if skip_save and load_spec is None:
            return None

        # If we need to save, update the number of saved tokens
        if not skip_save:
            tracker.num_saved_tokens = num_tokens_to_save
        save_spec = SaveSpec(skip_leading_tokens, not skip_save)

        # Calculate the token ids and slot mappings for load and save
        # OPTIMIZATION: pre-allocate the buffer for token ids and block ids
        token_ids = torch.tensor(input_token_ids)[:num_tokens_to_save]

        # For the load operation: check whether the request is scheduled
        # to load
        if load_spec is not None and load_spec.can_load:
            logger.debug(
                "Scheduled to load %d tokens for request %s",
                load_spec.mooncake_cached_tokens,
                tracker.req_id,
            )
        else:
            # Do not load if not in `can_load` state
            load_spec = None
        logger.debug(
            "request:%s, meta save spec:%s, meta load spec:%s",
            tracker.req_id, save_spec, load_spec)
        return ReqMeta(
            req_id=tracker.req_id,
            token_ids=token_ids,
            block_ids=tracker.allocated_block_ids,
            save_spec=save_spec,
            load_spec=load_spec,
            is_last_chunk=is_last_chunk,
        )
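

# Worked example (hypothetical numbers) for the save-boundary logic above,
# with block_size=128 and discard_partial_chunks=True:
#
#   num_saved_tokens=128, len(token_ids)=250:
#       chunk_boundary = cdiv(129, 128) * 128 = 256
#       num_tokens_to_save = 250 // 128 * 128 = 128
#       128 < 256, so the save is skipped (no new full chunk yet)
#
#   num_saved_tokens=128, len(token_ids)=260:
#       num_tokens_to_save = 256 >= chunk_boundary, so tokens 128..256
#       are saved (skip_leading_tokens=128) and num_saved_tokens
#       becomes 256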


class MooncakeConnectorMetadata(KVConnectorMetadata):

    def __init__(self, unfinished_request_ids):
        self.requests = []
        self.unfinished_request_ids = unfinished_request_ids

    def add_request(self, req_meta: ReqMeta) -> None:
        """Add a request to the metadata.

        Args:
            req_meta (ReqMeta): the request metadata.
        """
        self.requests.append(req_meta)


@dataclass
class LasyerMultiBlockReqMeta:
    req_id: str
    keys: List[LayerMooncakeEngineKey]
    starts: List[int]
    ends: list[int]
    block_ids: list[int]
    layer_id: int


@dataclass
class MooncakeStoreConfig:
    local_hostname: str
    metadata_server: str
    global_segment_size: Union[int, str]
    local_buffer_size: int
    protocol: str
    device_name: str
    master_server_address: str
    use_ascend_direct: bool

    @staticmethod
    def from_file(file_path: str) -> "MooncakeStoreConfig":
        with open(file_path) as file:
            config = json.load(file)
        return MooncakeStoreConfig(
            local_hostname=config.get("local_hostname"),
            metadata_server=config.get("metadata_server"),
            global_segment_size=_parse_global_segment_size(
                config.get("global_segment_size",
                           DEFAULT_GLOBAL_SEGMENT_SIZE)),
            local_buffer_size=config.get("local_buffer_size",
                                         DEFAULT_LOCAL_BUFFER_SIZE),
            protocol=config.get("protocol", "tcp"),
            device_name=config.get("device_name", ""),
            master_server_address=config.get("master_server_address"),
            use_ascend_direct=config.get("use_ascend_direct", False))

    @staticmethod
    def load_from_env() -> "MooncakeStoreConfig":
        config_path = os.getenv("MOONCAKE_CONFIG_PATH")
        if not config_path:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
        return MooncakeStoreConfig.from_file(config_path)
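

# A minimal config file pointed to by MOONCAKE_CONFIG_PATH might look like
# the following (all values are placeholders; keys that from_file() reads
# with a default may be omitted):
#
#   {
#       "local_hostname": "192.168.0.2",
#       "metadata_server": "192.168.0.1:2379",
#       "global_segment_size": "3.125gb",
#       "local_buffer_size": 1073741824,
#       "protocol": "tcp",
#       "device_name": "",
#       "master_server_address": "192.168.0.1:50001"
#   }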


def _parse_global_segment_size(value) -> int:
    """
    Parse storage size strings with support for units: GB, MB, KB, B

    Args:
        value: Input value (int, str, or other convertible types)

    Returns:
        int: Size in bytes

    Raises:
        ValueError: For invalid format, missing number, or negative values
        TypeError: For unsupported input types
    """
    if isinstance(value, int):
        return value
    elif not isinstance(value, str):
        try:
            return int(value)
        except (TypeError, ValueError) as e:
            raise TypeError(
                f"Unsupported type for global_segment_size: {type(value)}"
            ) from e

    cleaned_input = value.strip().lower()
    if not cleaned_input:
        raise ValueError("global segment size cannot be empty.")

    UNIT_MULTIPLIERS = {
        'gb': 1024**3,  # 1 GB = 1024^3 bytes
        'mb': 1024**2,  # 1 MB = 1024^2 bytes
        'kb': 1024,  # 1 KB = 1024 bytes
        'b': 1  # 1 B = 1 byte
    }
    pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$'
    match = re.match(pattern, cleaned_input)

    if not match:
        raise ValueError(f"Invalid format: '{value}'")

    number_str = match.group(1)
    unit = match.group(2) or 'b'

    multiplier = UNIT_MULTIPLIERS[unit]
    return _convert_to_bytes(number_str, multiplier, value)
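

# Examples (illustrative): unit suffixes are case-insensitive and use
# binary multipliers, so all of the following hold:
#
#   _parse_global_segment_size("2gb")     == 2 * 1024**3
#   _parse_global_segment_size("512 MB")  == 512 * 1024**2
#   _parse_global_segment_size("3.125gb") == 3355443200
#   _parse_global_segment_size(1024)      == 1024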


def _convert_to_bytes(number_str: str, multiplier: int,
                      original_input: str) -> int:
    """
    Convert numeric string to byte count

    Args:
        number_str: Numeric portion of input
        multiplier: Unit conversion factor
        original_input: Original input string (for error messages)

    Returns:
        int: Byte count

    Raises:
        ValueError: For invalid numbers or negative results
    """
    try:
        numeric_value = float(number_str)
    except ValueError:
        raise ValueError(
            f"Invalid numeric value '{number_str}' in: '{original_input}'")
    # Calculate the byte count
    try:
        byte_count = int(numeric_value * multiplier)
    except OverflowError:
        raise ValueError(f"Storage size too large: '{original_input}'")
    return byte_count