[Feature][main]reconstruction kvpool connector to ascend connector (#4438)

### What this PR does / why we need it?
1.In short, we renamed the existing MooncakeStoreConnector to
AscendStoreConnector and extracted the storage engine interaction logic
into a new Backend class.
Associated RFC:https://github.com/vllm-project/vllm-ascend/issues/4329
2.Fixed the issue where the number of input parameters for the connector
was incorrect, introduced in vllm 0.11.2
### Does this PR introduce _any_ user-facing change?
change MooncakeStoreConnector to AscendStoreConnector
### How was this patch tested?

- vLLM version: v0.11.2

---------

Signed-off-by: fems14 <1804143737@qq.com>
This commit is contained in:
fems14
2025-11-28 18:08:37 +08:00
committed by GitHub
parent 554f16ae1f
commit 5447a039b9
25 changed files with 1489 additions and 1511 deletions

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,194 @@
import threading
from typing import Any, Optional
import torch
import zmq
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.forward_context import ForwardContext
from vllm.utils import logger
from vllm.utils.network_utils import make_zmq_socket
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
from vllm.v1.serial_utils import MsgpackDecoder
from vllm_ascend.distributed.kvpool.pool_scheduler import (
KVPoolScheduler, get_zmq_rpc_path_lookup)
from vllm_ascend.distributed.kvpool.pool_worker import KVPoolWorker
class AscendStoreConnector(KVConnectorBase_V1):
def __init__(self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional[KVCacheConfig] = None):
super().__init__(vllm_config=vllm_config,
role=role,
kv_cache_config=kv_cache_config)
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"use_layerwise", False)
connector_name = vllm_config.kv_transfer_config.kv_connector
if connector_name == "MooncakeConnectorStoreV1":
logger.warning(
"It is recommended to use the AscendStoreConnector, as the MoonCakeStoreConnector will be removed in the future."
)
self.kv_caches: dict[str, torch.Tensor] = {}
self._block_size = vllm_config.cache_config.block_size
self.sended_but_unfinished_reqs: set[str] = set()
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler = KVPoolScheduler(vllm_config,
self.use_layerwise)
else:
self.connector_worker = KVPoolWorker(
vllm_config,
self.use_layerwise,
)
assert self.connector_worker is not None
if vllm_config.parallel_config.rank == 0:
self.lookup_server = LookupKeyServer(self.connector_worker,
vllm_config,
self.use_layerwise)
############################################################
# Scheduler Side Methods
############################################################
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
############################################################
# Worker Side Methods
############################################################
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
assert self.connector_worker is not None
self.connector_worker.start_load_kv(self._get_connector_metadata())
def wait_for_layer_load(self, layer_name: str) -> None:
if not self.use_layerwise:
return
self.connector_worker.wait_for_layer_load()
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
if not self.use_layerwise:
return
if self.kv_role == "kv_consumer":
# Don't do save if the role is kv_consumer
return
self.connector_worker.save_kv_layer(self._get_connector_metadata())
def wait_for_save(self):
if self.kv_role == "kv_consumer":
# Don't do save if the role is kv_consumer
return
if self.use_layerwise:
return
self.connector_worker.wait_for_save(self._get_connector_metadata())
def get_finished(self,
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""Get the finished recving and sending requests."""
assert self.connector_worker is not None
meta = self._get_connector_metadata()
done_sending, done_recving = self.connector_worker.get_finished()
sended_and_finished: set[str] = set()
for item in list(self.sended_but_unfinished_reqs):
if item not in meta.unfinished_request_ids:
sended_and_finished.add(item)
self.sended_but_unfinished_reqs.remove(item)
for item in done_sending:
if item in meta.unfinished_request_ids:
self.sended_but_unfinished_reqs.add(item)
else:
sended_and_finished.add(item)
return sended_and_finished, done_recving
class LookupKeyServer:
def __init__(
self,
pool_worker: KVPoolWorker,
vllm_config: "VllmConfig",
use_layerwise: bool,
):
self.decoder = MsgpackDecoder()
self.decoder_tensor = MsgpackDecoder(torch.Tensor)
self.ctx = zmq.Context() # type: ignore[attr-defined]
socket_path = get_zmq_rpc_path_lookup(vllm_config)
self.socket = make_zmq_socket(
self.ctx,
socket_path,
zmq.REP, # type: ignore[attr-defined]
bind=True,
)
self.pool_worker = pool_worker
self.running = True
self.use_layerwise = use_layerwise
def process_request():
while self.running:
all_frames = self.socket.recv_multipart(copy=False)
token_len = int.from_bytes(all_frames[0], byteorder="big")
hash_frames = all_frames[1:]
hashes_str = self.decoder.decode(hash_frames)
result = self.pool_worker.lookup_scheduler(
token_len, hashes_str, self.use_layerwise)
response = result.to_bytes(4, "big")
self.socket.send(response)
self.thread = threading.Thread(target=process_request, daemon=True)
self.thread.start()
def close(self):
self.socket.close(linger=0)
# TODO: close the thread!

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,29 @@
from abc import ABC, abstractmethod
from vllm.config import ParallelConfig
class Backend(ABC):
def __init__(self, parallel_config: ParallelConfig):
pass
def set_device(self):
pass
def register_buffer(self, ptrs: list[int], lengths: list[int]):
pass
@abstractmethod
def exists(self, keys: list[str]) -> list[int]:
pass
@abstractmethod
def put(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]]):
pass
@abstractmethod
def get(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]]):
pass

View File

@@ -0,0 +1,74 @@
# Standard
from enum import Enum
import torch
from vllm.config import ParallelConfig
from vllm.utils import logger
from vllm_ascend.distributed.kvpool.backend.backend import Backend
class MmcDirect(Enum):
COPY_L2G = 0
COPY_G2L = 1
COPY_G2H = 2
COPY_H2G = 3
class MemcacheBackend(Backend):
def __init__(self, parallel_config: ParallelConfig):
try:
from memcache import DistributedObjectStore # type: ignore
except ImportError as e:
raise ImportError(
"Please install memcache by following the instructions at "
"https://gitee.com/ascend/memfabric_hybrid " # noqa: E501
"to run vLLM with MemcacheConnector.") from e
try:
self.rank = parallel_config.rank
self.store = DistributedObjectStore()
res = self.store.init(self.rank)
assert res == 0
except ValueError as e:
logger.error("Configuration loading failed: %s", e)
raise
except Exception as exc:
logger.error(
"An error occurred while loading the configuration: %s", exc)
raise
def set_device(self):
device = torch.device(f"npu:{self.rank}")
torch.npu.set_device(device)
def register_buffer(self, ptrs: list[int], sizes: list[int]):
for ptr, size in zip(ptrs, sizes):
ret_value = self.store.register_buffer(ptr, size)
if ret_value != 0:
raise RuntimeError("Memcache memory registration failed.")
def exists(self, keys: list[str]) -> list[int]:
return self.store.batch_is_exist(keys)
def get(self, key: list[str], addr: list[list[int]],
size: list[list[int]]):
try:
res = self.store.batch_get_into_layers(key, addr, size,
MmcDirect.COPY_G2L.value)
for value in res:
if value != 0:
logger.error(f"Failed to get key {key},res:{res}")
except Exception as e:
logger.error(f"Failed to get key {key}. {e}")
def put(self, key: list[str], addr: list[list[int]],
size: list[list[int]]):
try:
res = self.store.batch_put_from_layers(key, addr, size,
MmcDirect.COPY_L2G.value)
for value in res:
if value != 0:
logger.error(f"Failed to get key {key},res:{res}")
except Exception as e:
logger.error(f"Failed to put key {key},error:{e}")

View File

@@ -0,0 +1,188 @@
# Standard
import json
import os
import re
from dataclasses import dataclass
from typing import Union
# Third Party
from vllm.config import ParallelConfig
from vllm.utils import logger
from vllm.utils.network_utils import get_ip
from vllm_ascend.distributed.kvpool.backend.backend import Backend
from vllm_ascend.distributed.mooncake_transfer_engine import global_te
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
class MooncakeBackend(Backend):
def __init__(self, parallel_config: ParallelConfig):
try:
from mooncake.store import MooncakeDistributedStore # type: ignore
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run vLLM with MooncakeConnector.") from e
self.config = MooncakeStoreConfig.load_from_env()
self.store = MooncakeDistributedStore()
if self.config.protocol == "ascend":
local_hostname = get_ip()
transfer_engine = global_te.get_transfer_engine(local_hostname,
device_name=None)
self.local_seg = local_hostname + ":" + str(
transfer_engine.get_rpc_port())
ret = self.store.setup(self.local_seg, self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol,
self.config.device_name,
self.config.master_server_address,
transfer_engine.get_engine())
if ret != 0:
msg = "Initialize mooncake failed."
logger.error(msg)
raise RuntimeError(msg)
def register_buffer(self, ptrs: list[int], lengths: list[int]):
global_te.register_buffer(ptrs, lengths)
def exists(self, keys: list[str]) -> list[int]:
return self.store.batch_is_exist(keys)
def put(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]]):
try:
res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes)
for value in res:
if value < 0:
logger.error(f"Failed to put key {keys},res:{res}")
except Exception as e:
logger.error(f"Failed to put key {keys},error:{e}")
def get(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]]):
try:
res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes)
for value in res:
if value < 0:
logger.error(f"Failed to get key {keys}, res:{res}")
except Exception as e:
logger.error(f"Failed to get key {keys}, error:{e}")
@dataclass
class MooncakeStoreConfig:
local_hostname: str
metadata_server: str
global_segment_size: Union[int, str]
local_buffer_size: int
protocol: str
device_name: str
master_server_address: str
use_ascend_direct: bool
@staticmethod
def from_file(file_path: str) -> "MooncakeStoreConfig":
with open(file_path) as file:
config = json.load(file)
return MooncakeStoreConfig(
local_hostname=config.get("local_hostname"),
metadata_server=config.get("metadata_server"),
global_segment_size=_parse_global_segment_size(
config.get("global_segment_size",
DEFAULT_GLOBAL_SEGMENT_SIZE)),
local_buffer_size=(config.get("local_buffer_size",
DEFAULT_LOCAL_BUFFER_SIZE)),
protocol=config.get("protocol", "tcp"),
device_name=config.get("device_name", ""),
master_server_address=config.get("master_server_address"),
use_ascend_direct=config.get("use_ascend_direct", False))
@staticmethod
def load_from_env() -> "MooncakeStoreConfig":
config_path = os.getenv("MOONCAKE_CONFIG_PATH")
if not config_path:
raise ValueError(
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
return MooncakeStoreConfig.from_file(config_path)
def _parse_global_segment_size(value) -> int:
"""
Parse storage size strings with support for units: GB, MB, KB, B
Args:
value: Input value (int, str, or other convertible types)
Returns:
int: Size in bytes
Raises:
ValueError: For invalid format, missing number, or negative values
TypeError: For unsupported input types
"""
if isinstance(value, int):
return value
elif not isinstance(value, str):
try:
return int(value)
except (TypeError, ValueError) as e:
raise TypeError(
f"Unsupported type for global_segment_size: {type(value)}"
) from e
cleaned_input = value.strip().lower()
if not cleaned_input:
raise ValueError("global segment size cannot be empty.")
UNIT_MULTIPLIERS = {
'gb': 1024**3, # 1 GB = 1024^3 bytes
'mb': 1024**2, # 1 MB = 1024^2 bytes
'kb': 1024, # 1 KB = 1024 bytes
'b': 1 # 1 B = 1 byte
}
pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$'
match = re.match(pattern, cleaned_input)
if not match:
raise ValueError(f"Invalid format: '{value}'")
number_str = match.group(1)
unit = match.group(2) or 'b'
multiplier = UNIT_MULTIPLIERS[unit]
return _convert_to_bytes(number_str, multiplier, value)
def _convert_to_bytes(number_str: str, multiplier: int,
original_input: str) -> int:
"""
Convert numeric string to byte count
Args:
number_str: Numeric portion of input
multiplier: Unit conversion factor
original_input: Original input string (for error messages)
Returns:
int: Byte count
Raises:
ValueError: For invalid numbers or negative results
"""
try:
numeric_value = float(number_str)
except ValueError:
raise ValueError(
f"Invalid numeric value '{number_str}' in: '{original_input}'")
# Calculate byte count
try:
byte_count = int(numeric_value * multiplier)
except OverflowError:
raise ValueError(f"Storage size too large: '{original_input}'")
return byte_count

View File

@@ -0,0 +1,364 @@
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple, Union
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
KVConnectorMetadata
from vllm.utils import logger
from vllm.utils.math_utils import cdiv
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.core.sched.output import NewRequestData
#Parameters related to the key
@dataclass
class KeyMetadata:
"""name of the LLM model"""
model_name: str
""" worker id when running under a distributed setting """
head_or_tp_rank: int
@dataclass(order=True)
class PoolKey:
key_metadata: KeyMetadata
chunk_hash: str
def __hash__(self):
return hash((
self.key_metadata.model_name,
self.key_metadata.head_or_tp_rank,
self.chunk_hash,
))
def to_string(self):
return (
f"{self.key_metadata.model_name}"
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}"
)
def split_layers(self, num_layers: int) -> List["LayerPoolKey"]:
"""Split the key into multiple keys for each layer"""
keys = []
for layer_id in range(num_layers):
keys.append(
LayerPoolKey(
self.key_metadata,
self.chunk_hash,
layer_id,
))
return keys
@dataclass(order=True)
class LayerPoolKey(PoolKey):
"""A key for the layer cache engine"""
layer_id: int
def __hash__(self):
return hash((
self.key_metadata.model_name,
self.key_metadata.head_or_tp_rank,
self.chunk_hash,
self.layer_id,
))
def to_string(self):
return (
f"{self.key_metadata.model_name}"
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}@{self.layer_id}"
)
class ChunkedTokenDatabase():
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool):
self.metadata = metadata
self.block_size = block_size
self.use_mla = use_mla
self.kv_caches_base_addr: list[int] = []
self.block_len: list[int] = []
def _make_key_by_hash(self,
chunk_hash: str,
layer_id: Optional[int] = None):
assert self.metadata is not None
return PoolKey(
self.metadata,
chunk_hash,
)
def set_kv_caches_base_addr(self, kv_caches_base_addr: list[int]):
self.kv_caches_base_addr = kv_caches_base_addr
def set_block_len(self, block_len: list[int]):
self.block_len = block_len
def prepare_value(self, start: int, end: int, block_ids: list[int]):
addr_list = []
size_list = []
block_id = block_ids[start // self.block_size]
for index, base_addr in enumerate(self.kv_caches_base_addr):
block_len = (self.block_len[index % 2]
if self.use_mla else self.block_len[0])
addr = base_addr + block_id * block_len
length = int(block_len / self.block_size * (end - start))
addr_list.append(addr)
size_list.append(length)
return addr_list, size_list, block_id
def prepare_value_layer(self, start: int, end: int, block_ids: list[int],
layer_id: int):
block_id = block_ids[start // self.block_size]
if self.use_mla:
addr_k = self.kv_caches_base_addr[layer_id *
2] + block_id * self.block_len[0]
addr_v = self.kv_caches_base_addr[layer_id * 2 +
1] + block_id * self.block_len[1]
length_k = int(self.block_len[0] / self.block_size * (end - start))
length_v = int(self.block_len[1] / self.block_size * (end - start))
size_list = [length_k, length_v]
else:
addr_k = self.kv_caches_base_addr[layer_id *
2] + block_id * self.block_len[0]
addr_v = self.kv_caches_base_addr[layer_id * 2 +
1] + block_id * self.block_len[0]
length = int(self.block_len[0] / self.block_size * (end - start))
size_list = [length, length]
addr_list = [addr_k, addr_v]
return addr_list, size_list
def process_tokens(
self,
token_len: int,
block_hashes: Union[list[BlockHash], list[str]],
mask_num: int = 0,
) -> Iterable[Tuple[int, int, PoolKey]]:
"""Process the tokens and return the corresponding cache engine keys.
:param Union[torch.Tensor, List[int]] tokens: The tokens to process.
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
have the same length as tokens. And the mask should ALWAYS be like
FFFFFTTTTTTT, where True means the tokens needs to be matched,
and the Falses will ALWAYS be at the PREFIX of the tensor.
:param bool make_key: Whether to make the cache engine key or not.
If False, the hash value will be returned instead.
:returns: A iterable of tuples with three elements. The first element
is the start index of the tokens for the key. The second element
is the end index of the tokens for the key. The third element is
the cache engine key (or hash) for the tokens.
:raises: ValueError if the number of Falses in the mask is not a
multiple of the chunk size.
"""
if not block_hashes:
return
if not isinstance(block_hashes[0], str):
block_hashes = [
h.hex() # type: ignore[union-attr]
for h in block_hashes
]
start_idx = 0
for chunk_id, hash_val in enumerate(block_hashes):
start_idx = chunk_id * self.block_size
if start_idx >= token_len:
break
end_idx = min(start_idx + self.block_size, token_len)
if start_idx < mask_num:
continue
else:
yield start_idx, end_idx, self._make_key_by_hash(hash_val)
#Parameters related to the connector metadata
@dataclass
class LoadSpec:
# Number of tokens cached in vLLM
vllm_cached_tokens: int
# Number of tokens that are cached in kvpool
kvpool_cached_tokens: int
# Whether the scheduler allow us to load the tokens
can_load: bool
@dataclass
class RequestTracker:
# Request id
req_id: str
# The token ids that has been scheduled so far
token_len: int
# The block ids that has been allocated so far
# NOTE: allocated blocks could be more than the number of tokens
# FIXME: need to check whether the block ids will be changed after
# preemption
allocated_block_ids: list[int]
# The number of tokens that has been savd
num_saved_tokens: int = 0
@staticmethod
def from_new_request(
new_request: "NewRequestData",
num_tokens_to_compute: int,
) -> "RequestTracker":
"""Create the request tracker from a new request.
Args:
new_request (NewRequestData): the new request data.
num_tokens_to_compute (int): the number of tokens that will
be 'computed', including the `num_computed_tokens` (vLLM's
local cache hit) and new tokens that will be scheduled.
"""
unfolded_block_ids = []
if not isinstance(new_request.block_ids[0], list):
unfolded_block_ids = new_request.block_ids.copy()
else:
unfolded_block_ids = new_request.block_ids[0].copy()
return RequestTracker(
req_id=new_request.req_id,
token_len=num_tokens_to_compute,
allocated_block_ids=unfolded_block_ids,
num_saved_tokens=0,
)
def update(
self,
new_token_ids: list[int],
new_block_ids: Union[tuple[list[int], ...], list[int]],
) -> None:
"""Update the request tracker when a running request is
scheduled again
"""
self.token_len = self.token_len + len(new_token_ids)
if len(new_block_ids) == 0:
new_block_ids = []
elif isinstance(new_block_ids, tuple):
new_block_ids = new_block_ids[0]
elif isinstance(new_block_ids, list):
pass
else:
raise ValueError(
f"Unsupported new_block_ids type {type(new_block_ids)}")
self.allocated_block_ids.extend(new_block_ids)
@dataclass
class ReqMeta:
# Request id
req_id: str
# Request tokens
token_len_chunk: int
block_ids: list[int]
block_hashes: list[BlockHash]
can_save: Optional[bool] = None
# load_spec
load_spec: Optional[LoadSpec] = None
is_last_chunk: Optional[bool] = None
@staticmethod
def from_request_tracker(
tracker: RequestTracker,
block_size: int,
load_spec: Optional[LoadSpec] = None,
skip_save: Optional[bool] = False,
block_hashes: list[BlockHash] = [],
is_last_chunk: Optional[bool] = None,
discard_partial_chunks: bool = True,
) -> Optional["ReqMeta"]:
"""Create the request metadata from a request tracker.
Args:
tracker (RequestTracker): the request tracker.
block_size (int): the block size in vLLM.
load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
skip_save (bool): whether to skip the save operation.
discard_partial_chunks (bool): whether to discard partial chunks.
Returns:
the request metadata if we need to perform load/save
operations, None otherwise.
"""
input_token_len = tracker.token_len
# For save operation: do not save if the following condition is met
# 1. has already been saved before (num_saved_tokens > 0)
# 2. number of unsaved tokens is not reached the chunk boundary
chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) *
block_size if discard_partial_chunks else 0)
# Calculate number of tokens to save based on discard_partial_chunks
# setting
num_tokens_to_save = ((input_token_len // block_size * block_size)
if discard_partial_chunks else input_token_len)
skip_save = skip_save or num_tokens_to_save < chunk_boundary
if skip_save and load_spec is None:
return None
# If we need to save, update the number of saved tokens
if not skip_save:
tracker.num_saved_tokens = num_tokens_to_save
# # For load operation: check whether the request is scheduled to load
if load_spec is not None and load_spec.can_load:
logger.debug(
"Scheduled to load %d tokens for request %s",
load_spec.kvpool_cached_tokens,
tracker.req_id,
)
else:
# Do not load if not in `can_load` state
load_spec = None
logger.debug(
f"request:{tracker.req_id}, meta save spec:{not skip_save}, meta load spec:{load_spec}"
)
return ReqMeta(
req_id=tracker.req_id,
token_len_chunk=num_tokens_to_save,
block_ids=tracker.allocated_block_ids,
can_save=not skip_save,
load_spec=load_spec,
block_hashes=block_hashes,
is_last_chunk=is_last_chunk,
)
class AscendConnectorMetadata(KVConnectorMetadata):
def __init__(self, unfinished_request_ids):
self.requests = []
self.unfinished_request_ids = unfinished_request_ids
def add_request(self, req_meta: ReqMeta) -> None:
"""Add a request to the metadata.
Args:
req_meta (ReqMeta): the request metadata.
"""
self.requests.append(req_meta)
@dataclass
class LasyerMultiBlockReqMeta:
req_id: str
keys: List[LayerPoolKey]
starts: List[int]
ends: list[int]
block_ids: list[int]
layer_id: int
is_last_chunk: bool = True

View File

@@ -0,0 +1,246 @@
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional
import torch
from vllm.utils import logger
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm_ascend.distributed.kvpool.backend.backend import Backend
# isort: off
from vllm_ascend.distributed.kvpool.config_data import (ChunkedTokenDatabase,
LasyerMultiBlockReqMeta
)
# isort: on
class KVTransferThread(threading.Thread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, ready_event: threading.Event, name: str):
super().__init__(daemon=True, name=name)
self.m_store = m_store
self.ready_event = ready_event
self.tp_rank = tp_rank
self.token_database = token_database
self.done_task_lock = threading.Lock()
self.request_queue: queue.Queue[Any] = queue.Queue()
# TODO(jianzs): make this configurable
self.executor = ThreadPoolExecutor(max_workers=32)
self.finished_requests: set[str] = set()
def add_request(
self,
req_id: str,
token_len: int,
block_ids: list[int],
block_hashes: list[BlockHash],
mask_num: int = 0,
is_last_chunk: Optional[bool] = None,
) -> torch.Tensor:
req = ({
"req_id": req_id,
"token_len": token_len,
"block_ids": block_ids,
"block_hashes": block_hashes,
"mask_num": mask_num,
"is_last_chunk": is_last_chunk,
})
self.request_queue.put(req)
def get_and_clear_finished_requests(self) -> set[str]:
"""
Get and clear the requests that have been completed.
Returns:
A set of request IDs that have been completed.
"""
with self.done_task_lock:
finished_requests = self.finished_requests.copy()
self.finished_requests.clear()
return finished_requests
def set_finished_request(self, req_id):
with self.done_task_lock:
self.finished_requests.add(req_id)
def run(self):
"""Run the thread to handle KV cache transfer requests."""
self.m_store.set_device()
self.ready_event.set()
while True:
try:
request_data = self.request_queue.get()
if request_data is None:
logger.warning("Received a None request!")
self.request_queue.task_done()
continue
self._handle_request(request_data)
except Exception as e:
logger.error(f"Error in KVCacheTransferThread: {e}")
def _handle_request(self, req_meta: dict[str, Any]):
pass
class KVCacheStoreSendingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, put_step: int, ready_event: threading.Event):
super().__init__(m_store,
token_database,
tp_rank,
ready_event,
name="KVCacheSendingThread")
self.put_step = put_step
def _handle_request(self, req_meta: dict[str, Any]):
token_len = req_meta["token_len"]
mask_num = req_meta["mask_num"]
block_ids = req_meta["block_ids"]
block_hashes = req_meta["block_hashes"]
req_id = req_meta["req_id"]
is_last_chunk = req_meta["is_last_chunk"]
addr_list = []
size_list = []
key_list = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes, mask_num):
addr, size, _ = self.token_database.prepare_value(
start, end, block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step]
size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step]
if key_list_tp:
torch.npu.current_stream().synchronize()
self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
if is_last_chunk:
self.set_finished_request(req_id)
self.request_queue.task_done()
class KVCacheStoreRecvingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, ready_event: threading.Event):
super().__init__(m_store,
token_database,
tp_rank,
ready_event,
name="KVCacheStoreRecvingThread")
def _handle_request(self, req_meta: dict[str, Any]):
token_len = req_meta["token_len"]
mask_num = req_meta["mask_num"]
block_ids = req_meta["block_ids"]
req_id = req_meta["req_id"]
block_hashes = req_meta["block_hashes"]
addr_list = []
size_list = []
key_list = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes, mask_num):
addr, size, _ = self.token_database.prepare_value(
start, end, block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
key_list_c = key_list[self.tp_rank %
len(key_list):] + key_list[:self.tp_rank %
len(key_list)]
addr_list_c = addr_list[self.tp_rank %
len(addr_list):] + addr_list[:self.tp_rank %
len(addr_list)]
size_list_c = size_list[self.tp_rank %
len(size_list):] + size_list[:self.tp_rank %
len(size_list)]
self.m_store.get(key_list_c, addr_list_c, size_list_c)
self.set_finished_request(req_id)
self.request_queue.task_done()
class KVCacheStoreLayerSendingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, put_step: int, ready_event: threading.Event,
num_layers: int):
super().__init__(m_store,
token_database,
tp_rank,
ready_event,
name="KVCacheStoreLayerSendingThread")
self.final_layer_id = num_layers - 1
self.put_step = put_step
def add_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
self.request_queue.put(req_meta)
def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta):
addr_list = []
size_list = []
key_list = []
for index, key in enumerate(req_meta.keys):
addr, size = self.token_database.prepare_value_layer(
req_meta.starts[index], req_meta.ends[index],
req_meta.block_ids, req_meta.layer_id)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step]
size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step]
if key_list_tp:
torch.npu.current_stream().synchronize()
self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
if req_meta.layer_id == self.final_layer_id and req_meta.is_last_chunk:
self.set_finished_request(req_meta.req_id)
self.request_queue.task_done()
class KVCacheStoreLayerRecvingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, ready_event: threading.Event,
get_event: threading.Event):
super().__init__(m_store,
token_database,
tp_rank,
ready_event,
name="KVCacheStoreLayerRecvingThread")
self.get_event = get_event
def add_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
self.request_queue.put(req_meta)
def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta):
addr_list = []
size_list = []
key_list = []
for index, key in enumerate(req_meta.keys):
addr, size = self.token_database.prepare_value_layer(
req_meta.starts[index], req_meta.ends[index],
req_meta.block_ids, req_meta.layer_id)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
key_list_c = key_list[self.tp_rank %
len(key_list):] + key_list[:self.tp_rank %
len(key_list)]
addr_list_c = addr_list[self.tp_rank %
len(addr_list):] + addr_list[:self.tp_rank %
len(addr_list)]
size_list_c = size_list[self.tp_rank %
len(size_list):] + size_list[:self.tp_rank %
len(size_list)]
self.m_store.get(key_list_c, addr_list_c, size_list_c)
self.request_queue.task_done()
self.get_event.set()

View File

@@ -0,0 +1,318 @@
from typing import Any, Optional
import vllm.envs as envs
import zmq
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
KVConnectorMetadata
from vllm.utils import logger
from vllm.utils.network_utils import make_zmq_socket
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import Request
from vllm.v1.serial_utils import MsgpackEncoder
from vllm_ascend.distributed.kvpool.config_data import (
AscendConnectorMetadata, LoadSpec, ReqMeta, RequestTracker)
class KVPoolScheduler:
def __init__(self, vllm_config: "VllmConfig", use_layerwise):
self.client = LookupKeyClient(vllm_config)
self.use_layerwise = use_layerwise
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_load", False)
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
# request_id -> (vllm cached tokes, kvpool cached tokens)
self.load_specs: dict[str, LoadSpec] = {}
self._block_size = vllm_config.cache_config.block_size
# request_id -> full_token_ids
self._request_trackers: dict[str, RequestTracker] = {}
# Whether to discard partial chunks
self._discard_partial_chunks = (
vllm_config.kv_transfer_config.get_from_extra_config(
"discard_partial_chunks", True))
self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {}
self._unfinished_request_ids: set[str] = set()
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
"""
Check for external KV cache hit.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
if self.kv_role == "kv_consumer" and not self.consumer_is_to_load:
return 0, False
if self._discard_partial_chunks:
token_len = len(request.prompt_token_ids
) // self._block_size * self._block_size
else:
token_len = len(request.prompt_token_ids)
num_external_hit_tokens = self.client.lookup(token_len,
request.block_hashes)
if num_external_hit_tokens == request.num_tokens:
num_external_hit_tokens -= 1
need_to_allocate = num_external_hit_tokens - num_computed_tokens
logger.info(
"Reqid: %s, Total tokens %d, kvpool hit tokens: %d, need to load: %d",
request.request_id,
request.num_tokens,
num_external_hit_tokens,
need_to_allocate,
)
if need_to_allocate <= 0:
return 0, False
self.load_specs[request.request_id] = LoadSpec(
vllm_cached_tokens=num_computed_tokens,
kvpool_cached_tokens=num_external_hit_tokens,
can_load=False,
)
return need_to_allocate, self.load_async and not self.use_layerwise
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
"""
Update KVConnector state after temporary buffer alloc.
For SharedStorageConnector, update _request_needs_load
if the CacheManager this allocated blocks for us.
"""
local_block_ids = []
if num_external_tokens > 0:
local_block_ids = blocks.get_block_ids()[0]
self._unfinished_requests[request.request_id] = (request,
local_block_ids)
self._unfinished_request_ids.add(request.request_id)
if request.request_id not in self.load_specs:
# No KV tokens from external KV cache, return
return
if num_external_tokens == 0:
# No need to load anything
self.load_specs[request.request_id].can_load = False
return
assert (
num_external_tokens > 0 and num_external_tokens
== self.load_specs[request.request_id].kvpool_cached_tokens -
self.load_specs[request.request_id].vllm_cached_tokens
), (f"Mismatch in number of tokens: {num_external_tokens} vs "
f"{self.load_specs[request.request_id].kvpool_cached_tokens} - "
f"{self.load_specs[request.request_id].vllm_cached_tokens}"
f" for request {request.request_id}")
self.load_specs[request.request_id].can_load = True
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
"""Attach the connector metadata to the request object.
This function should NOT modify other fields in the scheduler_output
except the `kv_connector_metadata` field.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
force_skip_save = self.kv_role == "kv_consumer"
for finished_req_id in scheduler_output.finished_req_ids:
self._request_trackers.pop(finished_req_id, None)
self._unfinished_requests.pop(finished_req_id, None)
self._unfinished_request_ids.discard(finished_req_id)
meta = AscendConnectorMetadata(self._unfinished_request_ids)
for request in scheduler_output.scheduled_new_reqs:
# Right now, we only load KV for new requests
load_spec = self.load_specs.pop(request.req_id, None)
num_tokens_to_compute = (
request.num_computed_tokens +
scheduler_output.num_scheduled_tokens[request.req_id])
request_tracker = RequestTracker.from_new_request(
request, num_tokens_to_compute)
self._request_trackers[request.req_id] = request_tracker
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
self._block_size * self._block_size)
if self._discard_partial_chunks else len(
request.prompt_token_ids))
request_tuple = self._unfinished_requests.get(request.req_id)
request_real = request_tuple[0] # type: ignore[index]
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=load_spec,
skip_save=force_skip_save,
block_hashes=request_real.block_hashes,
is_last_chunk=request_tracker.token_len
>= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
cached_reqs = scheduler_output.scheduled_cached_reqs
if not force_skip_save:
for i, req_id in enumerate(cached_reqs.req_ids):
request_tracker = self._request_trackers[req_id]
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
req_tuple = self._unfinished_requests.get(req_id)
if req_tuple:
request = req_tuple[0]
num_current_tokens = request_tracker.token_len
new_token_ids = request.all_token_ids[
num_current_tokens:num_current_tokens + num_new_tokens]
else:
raise ValueError(
f"Request {req_id} is not in _unfinished_requests, "
f"but it is scheduled to be cached")
new_block_ids = cached_reqs.new_block_ids[i]
if not new_block_ids:
continue
request_tracker.update(new_token_ids, new_block_ids)
# decode not save
if request_tracker.token_len > len(request.prompt_token_ids):
continue
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
self._block_size * self._block_size)
if self._discard_partial_chunks else
len(request.prompt_token_ids))
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=None,
skip_save=force_skip_save,
block_hashes=request.block_hashes,
is_last_chunk=request_tracker.token_len
>= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
request_ids = [
req.req_id for req in scheduler_output.scheduled_new_reqs
]
for request_id, (request,
block_ids) in self._unfinished_requests.items():
if request_id not in request_ids and request_id not in cached_reqs.req_ids:
load_spec = self.load_specs.pop(request_id, None)
if not load_spec:
continue
num_tokens_to_compute = load_spec.kvpool_cached_tokens
if (num_tokens_to_compute % self._block_size
!= 0) and (num_tokens_to_compute
== len(request.prompt_token_ids) - 1):
num_tokens_to_compute = num_tokens_to_compute + 1
request_tracker = RequestTracker(
req_id=request_id,
token_len=num_tokens_to_compute,
allocated_block_ids=block_ids,
num_saved_tokens=0,
)
self._request_trackers[request_id] = request_tracker
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=load_spec,
skip_save=None,
block_hashes=request.block_hashes,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
return meta
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Once a request is finished, determine whether request blocks
should be freed now or will be sent asynchronously and freed later.
"""
if self.kv_role == "kv_consumer":
return False, None
tracker = self._request_trackers.get(request.request_id)
if tracker is not None and tracker.num_saved_tokens <= 0:
return False, None
delay_free_blocks = len(block_ids) > 0
if delay_free_blocks:
logger.info("Delaying free of %d blocks for request %s",
len(block_ids), request.request_id)
return delay_free_blocks, None
class LookupKeyClient:
def __init__(self, vllm_config: "VllmConfig"):
self.encoder = MsgpackEncoder()
self.ctx = zmq.Context() # type: ignore[attr-defined]
socket_path = get_zmq_rpc_path_lookup(vllm_config)
self.socket = make_zmq_socket(
self.ctx,
socket_path,
zmq.REQ, # type: ignore[attr-defined]
bind=False,
)
def lookup(self, token_len: int, block_hashes: list[BlockHash]) -> int:
hash_strs = [h.hex() for h in block_hashes]
hash_frames = self.encoder.encode(hash_strs)
token_len_bytes = token_len.to_bytes(4, byteorder="big")
all_frames = [token_len_bytes] + list(hash_frames)
self.socket.send_multipart(all_frames, copy=False)
resp = self.socket.recv()
result = int.from_bytes(resp, "big")
return result
def close(self):
self.socket.close(linger=0)
def get_zmq_rpc_path_lookup(
vllm_config: Optional["VllmConfig"] = None, ) -> str:
base_url = envs.VLLM_RPC_BASE_PATH
# Default to 0 if not configured
rpc_port = 0
if vllm_config is not None:
extra_config = vllm_config.kv_transfer_config.kv_connector_extra_config
if "lookup_rpc_port" in extra_config:
rpc_port = extra_config["lookup_rpc_port"]
elif "mooncake_rpc_port" in extra_config:
rpc_port = extra_config["mooncake_rpc_port"]
logger.warning(
"It is recommended to use the lookup_rpc_port, as the mooncake_rpc_port will be removed in the future."
)
logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port)
return f"ipc://{base_url}/lookup_rpc_port_{rpc_port}"

View File

@@ -0,0 +1,578 @@
# Standard
import math
import threading
from typing import Dict, Generator, Optional, Type
# Third Party
import torch
from vllm.config import VllmConfig
from vllm.utils import logger
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm_ascend.distributed.kvpool.backend.backend import Backend
from vllm_ascend.distributed.kvpool.backend.memcache_backend import \
MemcacheBackend
from vllm_ascend.distributed.kvpool.backend.mooncake_backend import \
MooncakeBackend
from vllm_ascend.distributed.kvpool.config_data import (
AscendConnectorMetadata, ChunkedTokenDatabase, KeyMetadata,
LasyerMultiBlockReqMeta)
from vllm_ascend.distributed.kvpool.kv_transfer import (
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
backend_map: Dict[str, Type[Backend]] = {
"mooncake": MooncakeBackend,
"memcache": MemcacheBackend,
}
class KVPoolWorker:
#The main class for the cache engine.
def __init__(
self,
vllm_config: VllmConfig,
use_layerwize: bool,
):
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
self.dp_rank = parallel_config.data_parallel_rank
self.use_mla = False
if (hasattr(model_config, "use_mla")
and isinstance(model_config.use_mla, bool)
and model_config.use_mla):
self.use_mla = True
self.use_layerwise = use_layerwize
self.tp_rank = parallel_config.rank
self.tp_size = parallel_config.tensor_parallel_size
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"backend", "mooncake")
self.block_size = vllm_config.cache_config.block_size
self.current_layer = 0
self.num_layers = model_config.get_num_layers(parallel_config)
self.block_size = vllm_config.cache_config.block_size
if self.use_mla:
self.num_kv_head = 1
else:
self.num_kv_head = model_config.get_total_num_kv_heads()
if self.num_kv_head < self.tp_size:
self.put_step = self.tp_size // self.num_kv_head
self.head_or_tp_rank = self.tp_rank // self.put_step
else:
self.head_or_tp_rank = self.tp_rank
self.put_step = 1
self.metadata = KeyMetadata(
model_config.model,
self.head_or_tp_rank,
)
self.token_database = ChunkedTokenDatabase(self.metadata,
self.block_size,
self.use_mla)
real_backend = backend_map.get(self.backend.lower())
self.m_store = real_backend( # type: ignore[misc]
parallel_config)
self.kv_send_thread: Optional[KVTransferThread] = None
self.kv_recv_thread: Optional[KVTransferThread] = None
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
first_kv_cache = first_kv_cache_tuple[0]
# TODO(tms): Find a more robust way to detect and handle MLA
if self.use_mla:
# MLA case.[num_block, block_size, 1, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe)
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks, block_shape_norm, block_shape_pe)
else:
# [num_block, block_size, num_head, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
kv_elem_size = first_kv_cache.element_size()
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
self.block_len = [kv_elem_size * math.prod(block_shape)]
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape)
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
self.use_mla, first_kv_cache.shape)
self.kv_caches = kv_caches
self.kv_caches_base_addr = []
ptrs = []
lengths = []
for cache_or_caches in kv_caches.values():
# Normalize to always be a list of caches
if self.use_mla:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
self.kv_caches_base_addr.append(base_addr)
region_len = self.num_blocks * self.block_len[i % 2]
ptrs.append(base_addr)
lengths.append(region_len)
else:
cache_list = [cache_or_caches
] if self.use_mla else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
self.kv_caches_base_addr.append(base_addr)
region_len = self.num_blocks * self.block_len[0]
ptrs.append(base_addr)
lengths.append(region_len)
self.m_store.register_buffer(ptrs, lengths)
self.token_database.set_kv_caches_base_addr(self.kv_caches_base_addr)
self.token_database.set_block_len(self.block_len)
if self.use_layerwise:
self.get_event = threading.Event()
if self.kv_role in ['kv_producer', 'kv_both']:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreLayerSendingThread(
self.m_store, self.token_database, self.tp_rank,
self.put_step, ready_event_sending, self.num_layers)
self.kv_send_thread.start()
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
self.m_store, self.token_database, self.tp_rank, ready_event,
self.get_event)
self.kv_recv_thread.start()
ready_event.wait()
else:
if self.kv_role in ['kv_producer', 'kv_both']:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreSendingThread(
self.m_store, self.token_database, self.tp_rank,
self.put_step, ready_event_sending)
self.kv_send_thread.start()
if self.load_async:
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreRecvingThread(
self.m_store, self.token_database, self.tp_rank,
ready_event)
self.kv_recv_thread.start()
ready_event.wait()
def start_load_kv(self, metadata: AscendConnectorMetadata):
self.current_layer = 0
self.layerwise_retrievers = []
for request in metadata.requests:
load_spec = request.load_spec
if load_spec is None or not load_spec.can_load: #load =0
continue
token_len = request.token_len_chunk
req_id = request.req_id
if (load_spec.kvpool_cached_tokens % self.block_size
!= 0) and (load_spec.kvpool_cached_tokens
== token_len - 1):
token_len = request.load_spec.kvpool_cached_tokens + 1
else:
token_len = request.load_spec.kvpool_cached_tokens
mask_num = (request.load_spec.vllm_cached_tokens //
self.block_size * self.block_size)
if self.use_layerwise:
layerwise_retriever = self.retrieve_layer(
req_id,
token_len,
request.block_ids,
request.block_hashes,
mask_num,
)
next(layerwise_retriever) # first layer load
self.layerwise_retrievers.append(layerwise_retriever)
else:
if self.load_async:
self.kv_recv_thread.add_request( # type: ignore[union-attr]
req_id,
token_len,
request.block_ids,
request.block_hashes,
mask_num,
)
else:
addr_list = []
size_list = []
key_list = []
for start, end, key in self.token_database.process_tokens(
token_len, request.block_hashes, mask_num):
addr, size, _ = self.token_database.prepare_value(
start, end, request.block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
key_list_c = key_list[self.tp_rank % len(
key_list):] + key_list[:self.tp_rank % len(key_list)]
addr_list_c = addr_list[self.tp_rank %
len(addr_list
):] + addr_list[:self.tp_rank %
len(addr_list)]
size_list_c = size_list[self.tp_rank %
len(size_list
):] + size_list[:self.tp_rank %
len(size_list)]
self.m_store.get(key_list_c, addr_list_c, size_list_c)
def wait_for_layer_load(self) -> None:
for layerwise_retriever in self.layerwise_retrievers:
ret_token_mask = next(layerwise_retriever)
if self.current_layer == self.num_layers - 1:
assert ret_token_mask is not None
num_retrieved_tokens = ret_token_mask.sum().item()
logger.debug(f"Retrieved {num_retrieved_tokens} tokens")
def save_kv_layer(self,
connector_metadata: AscendConnectorMetadata) -> None:
if self.current_layer == 0:
self.layerwise_storers = []
for request in connector_metadata.requests:
can_save = request.can_save
if can_save is None or not can_save:
continue
token_len = request.token_len_chunk
req_id = request.req_id
# TODO: whether need to remov saveThread
# no lookup, skipmask
skip_leading_tokens = self.lookup(token_len,
request.block_hashes,
self.use_layerwise)
if skip_leading_tokens == token_len:
if request.is_last_chunk:
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
req_id)
continue # skip this request
mask_num = (skip_leading_tokens // self.block_size *
self.block_size)
logger.info(
"Storing KV cache for %d out of %d tokens "
"(skip_leading_tokens=%d) for request %s",
token_len - skip_leading_tokens,
token_len,
skip_leading_tokens,
request.req_id,
)
layerwise_storer = self.store_layer(
req_id,
token_len,
block_hashes=request.block_hashes,
mask_num=mask_num,
block_ids=request.block_ids,
is_last_chunk=request.is_last_chunk,
)
self.layerwise_storers.append(layerwise_storer)
for layerwise_storer in self.layerwise_storers:
try:
next(layerwise_storer)
except Exception:
raise
self.current_layer = self.current_layer + 1
def wait_for_save(self, connector_metadata: AscendConnectorMetadata):
for request in connector_metadata.requests:
can_save = request.can_save
if can_save is None or not can_save:
continue
token_len = request.token_len_chunk
req_id = request.req_id
skip_leading_tokens = self.lookup(token_len, request.block_hashes,
self.use_layerwise)
if skip_leading_tokens == token_len:
if request.is_last_chunk:
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
req_id)
continue # skip this request
mask_num = (skip_leading_tokens // self.block_size *
self.block_size)
logger.info(
"Storing KV cache for %d out of %d tokens "
"(skip_leading_tokens=%d) for request %s",
token_len - skip_leading_tokens,
token_len,
skip_leading_tokens,
request.req_id,
)
self.kv_send_thread.add_request( # type: ignore[union-attr]
req_id,
token_len,
request.block_ids,
request.block_hashes,
mask_num,
request.is_last_chunk,
)
def retrieve_layer(
self,
req_id: str,
token_len: int,
block_ids: list[int],
block_hashes: list[BlockHash],
mask_num: int = 0,
) -> Generator[Optional[torch.Tensor], None, None]:
"""
Retrieve the KV cache in a layerwise manner.
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
have the same length as tokens. And the mask should ALWAYS be like
FFFFFTTTTTTT, where True means the tokens needs to be matched.
:param **kwargs: The additional arguments for the KV transfer which
will be passed into the npu_transfer.
return: A generator that yields Optional[torch.Tensor]. The tensor will
be the boolean mask indicating which tokens are retrieved and will
only be returned in the last iteration.
"""
num_required_tokens = token_len - mask_num
ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu")
starts = []
ends = []
keys = []
first_flag = True
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes, mask_num):
keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start)
ends.append(end)
keys.append(keys_multi_layer)
ret_mask[start:end] = True
if keys:
# Transpose the keys into layer major format
keys = [list(row) for row in zip(*keys)] # [num_layer,block_num]
for layer_id, keys_multi_chunk in enumerate(keys):
if not first_flag:
is_finish = self.get_event.wait(timeout=3) #try---cache
if not is_finish:
logger.info("Layerwise get failed")
self.get_event.clear()
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
starts, ends, block_ids,
layer_id)
self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type]
first_flag = False
yield None
else:
# If no cache are found, we still need to yield to avoid
# `StopIteration`
for layer_id in range(self.num_layers):
yield None
retrieved_tokens = torch.sum(ret_mask)
logger.debug(f"Retrieved {retrieved_tokens} "
f"out of {num_required_tokens} "
f"out of total {token_len} tokens")
yield ret_mask
def store_layer(
self,
req_id: str,
token_len: int,
block_ids: list[int],
block_hashes: list[BlockHash],
is_last_chunk: bool,
mask_num: int = 0,
) -> Generator[None, None, None]:
"""
Store the KV cache in a layerwise manner.
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
have the same length as tokens. And the mask should ALWAYS be like
FFFFFTTTTTTT, where True means the tokens needs to be matched.
:param **kwargs: The additional arguments for the storage backend which
will be passed into the gpu_connector.
return: A generator that yields None. In the first iteration, the
generator allocates the memory objects for all layers and moves
the KV cache of the first layer from GPU to CPU. In the next
iterations, it moves the KV cache of layer i from GPU to the memory
objects (on CPU) and puts the memory objects of layer i-1 to the
storage backends. In the last iteration, it puts the memory objects
of the last layer to the storage backends.
"""
num_stored_tokens = token_len - mask_num
starts = []
ends = []
keys = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes, mask_num):
keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start)
ends.append(end)
keys.append(keys_multi_layer) #[block_num,layer_num]
if keys:
keys = [list(row) for row in zip(*keys)] #[layer_num,block_num]
for layer_id, keys_multi_chunk in enumerate(keys):
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
starts, ends, block_ids,
layer_id, is_last_chunk)
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type]
yield
else:
for layer_id in range(self.num_layers):
yield
logger.debug(
f"Stored {num_stored_tokens} out of total {token_len} tokens")
def get_finished(self) -> tuple[set[str], set[str]]:
done_sending = (
self.kv_send_thread.
get_and_clear_finished_requests( # type: ignore[union-attr]
) if self.kv_role in ['kv_producer', 'kv_both'] else set())
done_recving = (
self.kv_recv_thread.
get_and_clear_finished_requests( # type: ignore[union-attr]
) if self.load_async else set())
logger.debug(
"Number of completed KV cache send requests: %d, receive "
"requests: %d, tp_rank:%d", len(done_sending), len(done_recving),
self.tp_rank)
return done_sending, done_recving
def lookup(
self,
token_len: int,
block_hashes: list[BlockHash],
use_layerwise: bool,
) -> int:
"""
Checks the existence of KV cache of the tokens from the cache engine.
:param tokens: the input tokens, with shape [seq_len]
:return: An int indicating how many prefix tokens are cached.
"""
end = 0
keys = []
try:
starts = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes):
if use_layerwise:
keys_multi_layer = key.split_layers(self.num_layers)
for item in keys_multi_layer:
keys.append(item.to_string())
else:
keys.append(key.to_string())
starts.append(start)
res = self.m_store.exists(keys) # type: ignore[assignment]
if use_layerwise:
res = self.check_all_layers_exists(res, self.num_layers)
for index, value in enumerate(res): # type: ignore[arg-type]
if value != 1:
return starts[index]
# all tokens where found, return the maximal end
except Exception as e:
logger.error(f"Remote connection failed in contains: {e}")
return start
return end
def lookup_scheduler(
self,
token_len: int,
block_hashes: list[BlockHash],
use_layerwise: bool,
) -> int:
"""
Checks the existence of KV cache of the tokens from the cache engine.
:param tokens: the input tokens, with shape [seq_len]
:return: An int indicating how many prefix tokens are cached.
"""
end = 0
keys = []
try:
starts = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes):
if use_layerwise:
keys_multi_layer = key.split_layers(self.num_layers)
for item in keys_multi_layer:
keys.append(item.to_string())
else:
keys.append(key.to_string())
starts.append(start)
multi_tp_keys = keys[:]
for i in range(1, min(self.tp_size, self.num_kv_head)):
for item in keys:
new_str = item.replace( # type: ignore[attr-defined]
"@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1)
multi_tp_keys.append(new_str)
res = self.m_store.exists(
multi_tp_keys) # type: ignore[assignment]
num_block = len(keys)
if use_layerwise:
res = self.check_all_layers_exists(res, self.num_layers)
num_block = len(keys) // self.num_layers
multi_tp_values = [
res[i * num_block:(i + 1) * num_block] # type: ignore[index]
for i in range(min(self.tp_size, self.num_kv_head))
]
index = self.find_min_first_non_one_index(multi_tp_values)
if index != -1:
return starts[index]
# all tokens where found, return the maximal end
except Exception as e:
logger.error(f"Remote connection failed in contains: {e}")
return start
return end
def check_all_layers_exists(self, res: list[int],
num_layers: int) -> list[int]:
total_chunks = len(res) // num_layers
result = []
for chunk_idx in range(total_chunks):
start = chunk_idx * num_layers
end = start + num_layers
chunk = res[start:end]
result.append(1 if all(x == 1 for x in chunk) else 0)
return result
def find_min_first_non_one_index(self, arr):
try:
return min(idx for row in arr for idx, val in enumerate(row)
if val != 1)
except ValueError:
return -1