[Refactor] Refactor of vllm_ascend/distributed module (#5719)

### What this PR does / why we need it?
Based on the RFC: https://github.com/vllm-project/vllm-ascend/issues/5604

This PR refactors vllm_ascend/distributed, moving all kv_transfer-related
code into a dedicated kv_transfer folder, mirroring the structure already
used in upstream vLLM.
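
For reference, the package layout introduced by this refactor, as implied by the
module paths registered and imported in the diffs below (only the parts touched
by this PR are listed):

vllm_ascend/distributed/kv_transfer/
    kv_p2p/             (MooncakeConnector, MooncakeLayerwiseConnector)
    kv_pool/
        ascend_store/   (AscendStoreConnector, pool scheduler/worker, backends)
        cpu_offload/    (CPU offloading metadata server and cache manager)
    utils/              (shared Mooncake transfer-engine helpers)
    ucm_connector       (UCMConnectorV1)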

### Does this PR introduce _any_ user-facing change?
NA

### How was this patch tested?


- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef

---------

Signed-off-by: lty <linhebiwen@gmail.com>
Author: lty
Date: 2026-01-15 08:57:40 +08:00
Committed by: GitHub
Parent: f34b3b8ee9
Commit: 295018ec0f
56 changed files with 300 additions and 293 deletions


@@ -0,0 +1,45 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from vllm.distributed.kv_transfer.kv_connector.factory import \
KVConnectorFactory
def register_connector():
KVConnectorFactory.register_connector(
"MooncakeConnectorV1",
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector",
"MooncakeConnector")
KVConnectorFactory.register_connector(
"MooncakeConnectorStoreV1",
"vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.ascend_store_connector",
"AscendStoreConnector")
KVConnectorFactory.register_connector(
"AscendStoreConnector",
"vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.ascend_store_connector",
"AscendStoreConnector")
KVConnectorFactory.register_connector(
"MooncakeLayerwiseConnector",
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector",
"MooncakeLayerwiseConnector")
KVConnectorFactory.register_connector(
"UCMConnector", "vllm_ascend.distributed.kv_transfer.ucm_connector",
"UCMConnectorV1")

File diff suppressed because it is too large.

File diff suppressed because it is too large.


@@ -0,0 +1,184 @@
import threading
from typing import Any, Optional
import torch
import zmq
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.forward_context import ForwardContext
from vllm.logger import logger
from vllm.utils.network_utils import make_zmq_socket
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
from vllm.v1.serial_utils import MsgpackDecoder
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_scheduler import (
KVPoolScheduler, get_zmq_rpc_path_lookup)
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_worker import \
KVPoolWorker
class AscendStoreConnector(KVConnectorBase_V1):
def __init__(self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional[KVCacheConfig] = None):
super().__init__(vllm_config=vllm_config,
role=role,
kv_cache_config=kv_cache_config)
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"use_layerwise", False)
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_put", False)
connector_name = vllm_config.kv_transfer_config.kv_connector
if connector_name == "MooncakeConnectorStoreV1":
logger.warning(
"It is recommended to use the AscendStoreConnector, as the MoonCakeStoreConnector will be removed in the future."
)
self.kv_caches: dict[str, torch.Tensor] = {}
self.sended_but_unfinished_reqs: set[str] = set()
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler = KVPoolScheduler(vllm_config,
self.use_layerwise)
else:
self.connector_worker = KVPoolWorker(
vllm_config,
self.use_layerwise,
)
assert self.connector_worker is not None
if vllm_config.parallel_config.rank == 0:
self.lookup_server = LookupKeyServer(self.connector_worker,
vllm_config,
self.use_layerwise)
############################################################
# Scheduler Side Methods
############################################################
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
############################################################
# Worker Side Methods
############################################################
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
assert self.connector_worker is not None
self.connector_worker.start_load_kv(self._get_connector_metadata())
def wait_for_layer_load(self, layer_name: str) -> None:
if not self.use_layerwise:
return
self.connector_worker.wait_for_layer_load()
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
if not self.use_layerwise:
return
if self.kv_role == "kv_consumer":
# Don't do save if the role is kv_consumer
return
self.connector_worker.save_kv_layer(self._get_connector_metadata())
def wait_for_save(self):
if self.kv_role == "kv_consumer" and not self.consumer_is_to_put:
# Don't do save if the role is kv_consumer
return
if self.use_layerwise:
return
self.connector_worker.wait_for_save(self._get_connector_metadata())
def get_finished(self,
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""Get the finished recving and sending requests."""
assert self.connector_worker is not None
done_sending, done_recving = self.connector_worker.get_finished(
finished_req_ids)
return done_sending, done_recving
class LookupKeyServer:
def __init__(
self,
pool_worker: KVPoolWorker,
vllm_config: "VllmConfig",
use_layerwise: bool,
):
self.decoder = MsgpackDecoder()
self.decoder_tensor = MsgpackDecoder(torch.Tensor)
self.ctx = zmq.Context() # type: ignore[attr-defined]
socket_path = get_zmq_rpc_path_lookup(vllm_config)
self.socket = make_zmq_socket(
self.ctx,
socket_path,
zmq.REP, # type: ignore[attr-defined]
bind=True,
)
self.pool_worker = pool_worker
self.running = True
self.use_layerwise = use_layerwise
def process_request():
while self.running:
all_frames = self.socket.recv_multipart(copy=False)
token_len = int.from_bytes(all_frames[0], byteorder="big")
hash_frames = all_frames[1:]
hashes_str = self.decoder.decode(hash_frames)
result = self.pool_worker.lookup_scheduler(
token_len, hashes_str, self.use_layerwise)
response = result.to_bytes(4, "big")
self.socket.send(response)
self.thread = threading.Thread(target=process_request, daemon=True)
self.thread.start()
def close(self):
self.socket.close(linger=0)
# TODO: close the thread!
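
A sketch of the lookup wire format handled by process_request above, mirroring
LookupKeyClient.lookup in pool_scheduler.py (values are illustrative):

from vllm.v1.serial_utils import MsgpackEncoder

encoder = MsgpackEncoder()
token_len = 1024
block_hashes = ["a1b2", "c3d4"]                     # hex block hashes
frames = [token_len.to_bytes(4, byteorder="big")]   # frame 0: token count
frames += list(encoder.encode(block_hashes))        # remaining frames: msgpack-encoded hashes
# socket.send_multipart(frames) on a REQ socket; the REP side answers with a
# 4-byte big-endian integer: the number of prefix tokens present in the store.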


@@ -0,0 +1,29 @@
from abc import ABC, abstractmethod
from vllm.config import ParallelConfig
class Backend(ABC):
def __init__(self, parallel_config: ParallelConfig):
pass
def set_device(self):
pass
def register_buffer(self, ptrs: list[int], lengths: list[int]):
pass
@abstractmethod
def exists(self, keys: list[str]) -> list[int]:
pass
@abstractmethod
def put(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]]):
pass
@abstractmethod
def get(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]]):
pass
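
A minimal sketch of this contract, using a toy in-process backend for illustration
only (it does not copy any real KV data; addrs/sizes are the raw device pointers
and byte lengths that a real backend would transfer):

class DictBackend(Backend):
    """Toy backend, for illustration of the interface only."""

    def __init__(self, parallel_config: ParallelConfig):
        super().__init__(parallel_config)
        self._store: dict[str, tuple[list[int], list[int]]] = {}

    def exists(self, keys: list[str]) -> list[int]:
        # 1 means present; KVTransferThread.lookup() treats anything else as a miss.
        return [1 if key in self._store else 0 for key in keys]

    def put(self, keys: list[str], addrs: list[list[int]],
            sizes: list[list[int]]):
        # A real backend copies the device buffers into the external store.
        for key, addr_list, size_list in zip(keys, addrs, sizes):
            self._store[key] = (addr_list, size_list)

    def get(self, keys: list[str], addrs: list[list[int]],
            sizes: list[list[int]]):
        # A real backend copies data from the store back into the buffers at addrs.
        for key in keys:
            assert key in self._store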


@@ -0,0 +1,96 @@
# Standard
from enum import Enum
import torch
from vllm.config import ParallelConfig
from vllm.logger import logger
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
Backend
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
class MmcDirect(Enum):
COPY_L2G = 0
COPY_G2L = 1
COPY_G2H = 2
COPY_H2G = 3
class MemcacheBackend(Backend):
def __init__(self, parallel_config: ParallelConfig):
try:
from memcache_hybrid import DistributedObjectStore # type: ignore
except ImportError as e:
raise ImportError(
"Please install memcache by following the instructions at "
"https://gitee.com/ascend/memfabric_hybrid " # noqa: E501
"to run vLLM with MemcacheConnector.") from e
try:
soc_version = get_ascend_device_type()
if soc_version in {AscendDeviceType.A2}:
import torch
from vllm.distributed import get_world_group
tmp_tensor = torch.zeros(1, device="npu")
output_tensor_list = [
torch.empty_like(tmp_tensor)
for _ in range(torch.distributed.get_world_size())
]
torch.distributed.all_gather(
output_tensor_list,
tmp_tensor,
group=get_world_group().device_group)
self.rank = parallel_config.rank
self.store = DistributedObjectStore()
res = self.store.init(self.rank)
assert res == 0
else:
self.rank = parallel_config.rank
self.store = DistributedObjectStore()
res = self.store.init(self.rank)
assert res == 0
except ValueError as e:
logger.error("Configuration loading failed: %s", e)
raise
except Exception as exc:
logger.error(
"An error occurred while loading the configuration: %s", exc)
raise
def set_device(self):
device = torch.device(f"npu:{self.rank}")
torch.npu.set_device(device)
def register_buffer(self, ptrs: list[int], sizes: list[int]):
soc_version = get_ascend_device_type()
if soc_version in {AscendDeviceType.A2}:
for ptr, size in zip(ptrs, sizes):
self.store.register_buffer(ptr, size)
else:
pass
def exists(self, keys: list[str]) -> list[int]:
return self.store.batch_is_exist(keys)
def get(self, key: list[str], addr: list[list[int]],
size: list[list[int]]):
try:
res = self.store.batch_get_into_layers(key, addr, size,
MmcDirect.COPY_G2L.value)
for value in res:
if value != 0:
logger.error(f"Failed to get key {key},res:{res}")
except Exception as e:
logger.error(f"Failed to get key {key}. {e}")
def put(self, key: list[str], addr: list[list[int]],
size: list[list[int]]):
try:
res = self.store.batch_put_from_layers(key, addr, size,
MmcDirect.COPY_L2G.value)
for value in res:
if value != 0:
logger.error(f"Failed to get key {key},res:{res}")
except Exception as e:
logger.error(f"Failed to put key {key},error:{e}")


@@ -0,0 +1,192 @@
# Standard
import json
import os
import re
from dataclasses import dataclass
from typing import Union
# Third Party
from mooncake.store import ReplicateConfig # type: ignore
from vllm.config import ParallelConfig
from vllm.logger import logger
from vllm.utils.network_utils import get_ip
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
Backend
from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import \
global_te
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
class MooncakeBackend(Backend):
def __init__(self, parallel_config: ParallelConfig):
try:
from mooncake.store import MooncakeDistributedStore # type: ignore
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run vLLM with MooncakeConnector.") from e
self.config = MooncakeStoreConfig.load_from_env()
self.store = MooncakeDistributedStore()
if self.config.protocol == "ascend":
local_hostname = get_ip()
transfer_engine = global_te.get_transfer_engine(local_hostname,
device_name=None)
self.local_seg = local_hostname + ":" + str(
transfer_engine.get_rpc_port())
ret = self.store.setup(self.local_seg, self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol,
self.config.device_name,
self.config.master_server_address,
transfer_engine.get_engine())
if ret != 0:
msg = "Initialize mooncake failed."
logger.error(msg)
raise RuntimeError(msg)
def register_buffer(self, ptrs: list[int], lengths: list[int]):
global_te.register_buffer(ptrs, lengths)
def exists(self, keys: list[str]) -> list[int]:
return self.store.batch_is_exist(keys)
def put(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]]):
try:
config = ReplicateConfig()
config.preferred_segment = self.local_seg
config.prefer_alloc_in_same_node = True
res = self.store.batch_put_from_multi_buffers(
keys, addrs, sizes, config)
for value in res:
if value < 0:
logger.error(f"Failed to put key {keys},res:{res}")
except Exception as e:
logger.error(f"Failed to put key {keys},error:{e}")
def get(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]]):
try:
res = self.store.batch_get_into_multi_buffers(
keys, addrs, sizes, True)
for value in res:
if value < 0:
logger.error(f"Failed to get key {keys}, res:{res}")
except Exception as e:
logger.error(f"Failed to get key {keys}, error:{e}")
@dataclass
class MooncakeStoreConfig:
metadata_server: str
global_segment_size: Union[int, str]
local_buffer_size: int
protocol: str
device_name: str
master_server_address: str
@staticmethod
def from_file(file_path: str) -> "MooncakeStoreConfig":
with open(file_path) as file:
config = json.load(file)
return MooncakeStoreConfig(
metadata_server=config.get("metadata_server"),
global_segment_size=_parse_global_segment_size(
config.get("global_segment_size",
DEFAULT_GLOBAL_SEGMENT_SIZE)),
local_buffer_size=_parse_global_segment_size(
config.get("local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE)),
protocol=config.get("protocol", "ascend"),
device_name=config.get("device_name", ""),
master_server_address=config.get("master_server_address"))
@staticmethod
def load_from_env() -> "MooncakeStoreConfig":
config_path = os.getenv("MOONCAKE_CONFIG_PATH")
if not config_path:
raise ValueError(
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
return MooncakeStoreConfig.from_file(config_path)
def _parse_global_segment_size(value) -> int:
"""
Parse storage size strings with support for units: GB, MB, KB, B
Args:
value: Input value (int, str, or other convertible types)
Returns:
int: Size in bytes
Raises:
ValueError: For invalid format, missing number, or negative values
TypeError: For unsupported input types
"""
if isinstance(value, int):
return value
elif not isinstance(value, str):
try:
return int(value)
except (TypeError, ValueError) as e:
raise TypeError(
f"Unsupported type for global_segment_size: {type(value)}"
) from e
cleaned_input = value.strip().lower()
if not cleaned_input:
raise ValueError("global segment size cannot be empty.")
UNIT_MULTIPLIERS = {
'gb': 1024**3, # 1 GB = 1024^3 bytes
'mb': 1024**2, # 1 MB = 1024^2 bytes
'kb': 1024, # 1 KB = 1024 bytes
'b': 1 # 1 B = 1 byte
}
pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$'
match = re.match(pattern, cleaned_input)
if not match:
raise ValueError(f"Invalid format: '{value}'")
number_str = match.group(1)
unit = match.group(2) or 'b'
multiplier = UNIT_MULTIPLIERS[unit]
return _convert_to_bytes(number_str, multiplier, value)
def _convert_to_bytes(number_str: str, multiplier: int,
original_input: str) -> int:
"""
Convert numeric string to byte count
Args:
number_str: Numeric portion of input
multiplier: Unit conversion factor
original_input: Original input string (for error messages)
Returns:
int: Byte count
Raises:
ValueError: For invalid numbers or negative results
"""
try:
numeric_value = float(number_str)
except ValueError:
raise ValueError(
f"Invalid numeric value '{number_str}' in: '{original_input}'")
# Calculate byte count
try:
byte_count = int(numeric_value * multiplier)
except OverflowError:
raise ValueError(f"Storage size too large: '{original_input}'")
return byte_count
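
A hedged example of the JSON file pointed to by MOONCAKE_CONFIG_PATH (addresses
are placeholders), together with two worked parses:

# config.json consumed by MooncakeStoreConfig.load_from_env():
# {
#     "metadata_server": "127.0.0.1:2379",
#     "global_segment_size": "3.125 GB",
#     "local_buffer_size": "1 GB",
#     "protocol": "ascend",
#     "device_name": "",
#     "master_server_address": "127.0.0.1:50051"
# }
assert _parse_global_segment_size("3.125 GB") == 3355443200  # == DEFAULT_GLOBAL_SEGMENT_SIZE
assert _parse_global_segment_size(1073741824) == 1073741824  # plain ints pass through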


@@ -0,0 +1,404 @@
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple, Union
import torch
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
KVConnectorMetadata
from vllm.logger import logger
from vllm.utils.math_utils import cdiv
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.core.sched.output import NewRequestData
# Parameters related to the key
@dataclass
class KeyMetadata:
"""name of the LLM model"""
model_name: str
""" worker id when running under a distributed setting """
head_or_tp_rank: int
""" Initialize the current prefill context model parallel rank """
pcp_rank: int
""" Initialize the current decode context model parallel rank """
dcp_rank: int
""" Initialize the current pipeline parallel rank """
pp_rank: int
@dataclass(order=True)
class PoolKey:
key_metadata: KeyMetadata
chunk_hash: str
def __hash__(self):
return hash((
self.key_metadata.model_name,
self.key_metadata.head_or_tp_rank,
self.key_metadata.pcp_rank,
self.key_metadata.dcp_rank,
self.key_metadata.pp_rank,
self.chunk_hash,
))
def to_string(self):
return (
f"{self.key_metadata.model_name}"
f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}"
f"@pp_rank:{self.key_metadata.pp_rank}@{self.chunk_hash}")
def split_layers(self, num_layers: int) -> List["LayerPoolKey"]:
"""Split the key into multiple keys for each layer"""
keys = []
for layer_id in range(num_layers):
keys.append(
LayerPoolKey(
self.key_metadata,
self.chunk_hash,
layer_id,
))
return keys
@dataclass(order=True)
class LayerPoolKey(PoolKey):
"""A key for the layer cache engine"""
layer_id: int
def __hash__(self):
return hash((
self.key_metadata.model_name,
self.key_metadata.head_or_tp_rank,
self.key_metadata.pcp_rank,
self.key_metadata.dcp_rank,
self.chunk_hash,
self.layer_id,
))
def to_string(self):
return (
f"{self.key_metadata.model_name}"
f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}@{self.layer_id}"
)
class ChunkedTokenDatabase():
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool,
partitions: Optional[List[int]]):
self.metadata = metadata
self.block_size = block_size
self.use_mla = use_mla
self.kv_caches_base_addr: list[int] = []
self.block_len: list[int] = []
self.partitions = partitions
def _make_key_by_hash(self,
chunk_hash: str,
layer_id: Optional[int] = None):
assert self.metadata is not None
return PoolKey(
self.metadata,
chunk_hash,
)
def set_kv_caches_base_addr(self, kv_caches_base_addr: list[int]):
self.kv_caches_base_addr = kv_caches_base_addr
def set_block_len(self, block_len: list[int]):
self.block_len = block_len
def prepare_value(self, start: int, end: int, block_ids: list[int]):
addr_list = []
size_list = []
block_id = block_ids[start // self.block_size]
for index, base_addr in enumerate(self.kv_caches_base_addr):
block_len = (self.block_len[index % 2]
if self.use_mla else self.block_len[0])
addr = base_addr + block_id * block_len
length = int(block_len / self.block_size * (end - start))
addr_list.append(addr)
size_list.append(length)
return addr_list, size_list, block_id
def prepare_value_layer(self, start: int, end: int, block_ids: list[int],
layer_id: int):
block_id = block_ids[start // self.block_size]
if self.use_mla:
addr_k = self.kv_caches_base_addr[layer_id *
2] + block_id * self.block_len[0]
addr_v = self.kv_caches_base_addr[layer_id * 2 +
1] + block_id * self.block_len[1]
length_k = int(self.block_len[0] / self.block_size * (end - start))
length_v = int(self.block_len[1] / self.block_size * (end - start))
size_list = [length_k, length_v]
else:
addr_k = self.kv_caches_base_addr[layer_id *
2] + block_id * self.block_len[0]
addr_v = self.kv_caches_base_addr[layer_id * 2 +
1] + block_id * self.block_len[0]
length = int(self.block_len[0] / self.block_size * (end - start))
size_list = [length, length]
addr_list = [addr_k, addr_v]
return addr_list, size_list
def process_tokens(
self,
token_len: int,
block_hashes: Union[list[BlockHash], list[str]],
mask_num: int = 0,
) -> Iterable[Tuple[int, int, PoolKey]]:
"""Process the tokens and return the corresponding cache engine keys.
:param Union[torch.Tensor, List[int]] tokens: The tokens to process.
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
have the same length as tokens. And the mask should ALWAYS be like
FFFFFTTTTTTT, where True means the tokens needs to be matched,
and the Falses will ALWAYS be at the PREFIX of the tensor.
:param bool make_key: Whether to make the cache engine key or not.
If False, the hash value will be returned instead.
:returns: A iterable of tuples with three elements. The first element
is the start index of the tokens for the key. The second element
is the end index of the tokens for the key. The third element is
the cache engine key (or hash) for the tokens.
:raises: ValueError if the number of Falses in the mask is not a
multiple of the chunk size.
"""
if not block_hashes:
return
if not isinstance(block_hashes[0], str):
block_hashes = [
h.hex() # type: ignore[union-attr]
for h in block_hashes
]
start_idx = 0
for chunk_id, hash_val in enumerate(block_hashes):
start_idx = chunk_id * self.block_size
if start_idx >= token_len:
break
end_idx = min(start_idx + self.block_size, token_len)
if start_idx < mask_num:
continue
else:
yield start_idx, end_idx, self._make_key_by_hash(hash_val)
def decode_adaptor_prefill_pp(self, key, addr, size):
if self.partitions is None or len(self.partitions) == 1:
return key, addr, size
new_key = []
new_addr = []
new_size = []
for i, (addr_list, size_list) in enumerate(zip(addr, size)):
start = 0
for j, part in enumerate(self.partitions):
# part * 2 because addr and size contain both k and v
end = len(addr_list) if j == len(
self.partitions) - 1 else start + part * 2
new_str = key[i].replace( # type: ignore[attr-defined]
"@pp_rank:0", f"@pp_rank:{j}", 1)
new_key.append(new_str)
new_addr.append(addr_list[start:end])
new_size.append(size_list[start:end])
start = end
return new_key, new_addr, new_size
# Parameters related to the connector metadata
@dataclass
class LoadSpec:
# Number of tokens cached in vLLM
vllm_cached_tokens: int
# Number of tokens that are cached in kvpool
kvpool_cached_tokens: int
# Whether the scheduler allow us to load the tokens
can_load: bool
token_len: int = 0
@dataclass
class RequestTracker:
# Request id
req_id: str
# The number of tokens that have been scheduled so far
token_len: int
# The block ids that have been allocated so far
# NOTE: allocated blocks could be more than the number of tokens
# FIXME: need to check whether the block ids will be changed after
# preemption
allocated_block_ids: list[int]
# The number of tokens that have been saved
num_saved_tokens: int = 0
@staticmethod
def from_new_request(
new_request: "NewRequestData",
num_tokens_to_compute: int,
) -> "RequestTracker":
"""Create the request tracker from a new request.
Args:
new_request (NewRequestData): the new request data.
num_tokens_to_compute (int): the number of tokens that will
be 'computed', including the `num_computed_tokens` (vLLM's
local cache hit) and new tokens that will be scheduled.
"""
unfolded_block_ids = []
if not isinstance(new_request.block_ids[0], list):
unfolded_block_ids = new_request.block_ids.copy()
else:
unfolded_block_ids = new_request.block_ids[0].copy()
return RequestTracker(
req_id=new_request.req_id,
token_len=num_tokens_to_compute,
allocated_block_ids=unfolded_block_ids,
num_saved_tokens=0,
)
def update(
self,
new_block_ids: Union[tuple[list[int], ...], list[int]],
) -> None:
"""Update the request tracker when a running request is
scheduled again
"""
if len(new_block_ids) == 0:
new_block_ids = []
elif isinstance(new_block_ids, tuple):
new_block_ids = new_block_ids[0]
elif isinstance(new_block_ids, list):
pass
else:
raise ValueError(
f"Unsupported new_block_ids type {type(new_block_ids)}")
self.allocated_block_ids.extend(new_block_ids)
@dataclass
class ReqMeta:
# Request id
req_id: str
# Request tokens
token_len_chunk: int
block_ids: list[int]
block_hashes: list[BlockHash]
can_save: Optional[bool] = None
# load_spec
load_spec: Optional[LoadSpec] = None
is_last_chunk: Optional[bool] = None
current_event: Optional[torch.npu.Event] = None
@staticmethod
def from_request_tracker(
tracker: RequestTracker,
block_size: int,
load_spec: Optional[LoadSpec] = None,
skip_save: Optional[bool] = False,
block_hashes: list[BlockHash] = [],
is_last_chunk: Optional[bool] = None,
discard_partial_chunks: bool = True,
) -> Optional["ReqMeta"]:
"""Create the request metadata from a request tracker.
Args:
tracker (RequestTracker): the request tracker.
block_size (int): the block size in vLLM.
load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
skip_save (bool): whether to skip the save operation.
discard_partial_chunks (bool): whether to discard partial chunks.
Returns:
the request metadata if we need to perform load/save
operations, None otherwise.
"""
input_token_len = tracker.token_len
# For save operation: do not save if the following condition is met
# 1. has already been saved before (num_saved_tokens > 0)
# 2. number of unsaved tokens is not reached the chunk boundary
chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) *
block_size if discard_partial_chunks else 0)
# Calculate number of tokens to save based on discard_partial_chunks
# setting
num_tokens_to_save = ((input_token_len // block_size * block_size)
if discard_partial_chunks else input_token_len)
skip_save = skip_save or num_tokens_to_save < chunk_boundary
if skip_save and load_spec is None:
return None
# If we need to save, update the number of saved tokens
if not skip_save:
tracker.num_saved_tokens = num_tokens_to_save
# For load operation: check whether the request is scheduled to load
if load_spec is not None and load_spec.can_load:
logger.debug(
"Scheduled to load %d tokens for request %s",
load_spec.kvpool_cached_tokens,
tracker.req_id,
)
else:
# Do not load if not in `can_load` state
load_spec = None
logger.debug(
f"request:{tracker.req_id}, meta save spec:{not skip_save}, meta load spec:{load_spec}"
)
return ReqMeta(
req_id=tracker.req_id,
token_len_chunk=num_tokens_to_save,
block_ids=tracker.allocated_block_ids,
can_save=not skip_save,
load_spec=load_spec,
block_hashes=block_hashes,
is_last_chunk=is_last_chunk,
)
class AscendConnectorMetadata(KVConnectorMetadata):
def __init__(self, unfinished_request_ids):
self.requests = []
self.unfinished_request_ids = unfinished_request_ids
def add_request(self, req_meta: ReqMeta) -> None:
"""Add a request to the metadata.
Args:
req_meta (ReqMeta): the request metadata.
"""
self.requests.append(req_meta)
@dataclass
class LasyerMultiBlockReqMeta:
req_id: str
keys: List[LayerPoolKey]
starts: List[int]
ends: list[int]
block_ids: list[int]
layer_id: int
is_last_chunk: Optional[bool] = True
current_event: Optional[torch.npu.Event] = None
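
A worked example of the chunking performed by ChunkedTokenDatabase.process_tokens
and the resulting store keys (model name and hashes are made up):

meta = KeyMetadata(model_name="my-model", head_or_tp_rank=0,
                   pcp_rank=0, dcp_rank=0, pp_rank=0)
db = ChunkedTokenDatabase(meta, block_size=128, use_mla=False, partitions=None)
# 300 scheduled tokens and 3 block hashes with block_size=128 yield three chunks:
for start, end, key in db.process_tokens(300, ["aa", "bb", "cc"]):
    print(start, end, key.to_string())
# 0   128  my-model@pcp0@dcp0@head_or_tp_rank:0@pp_rank:0@aa
# 128 256  my-model@pcp0@dcp0@head_or_tp_rank:0@pp_rank:0@bb
# 256 300  my-model@pcp0@dcp0@head_or_tp_rank:0@pp_rank:0@cc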


@@ -0,0 +1,361 @@
import queue
import threading
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import Any
import torch
from vllm.logger import logger
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
Backend
# isort: off
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
ChunkedTokenDatabase,
LasyerMultiBlockReqMeta,
ReqMeta,
)
# isort: on
class KVTransferThread(threading.Thread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
block_size: int, tp_rank: int, dcp_size: int,
ready_event: threading.Event, name: str):
super().__init__(daemon=True, name=name)
self.m_store = m_store
self.ready_event = ready_event
self.block_size = block_size
self.tp_rank = tp_rank
self.dcp_size = dcp_size
self.token_database = token_database
self.done_task_lock = threading.Lock()
self.request_queue: queue.Queue[Any] = queue.Queue()
# TODO(jianzs): make this configurable
self.executor = ThreadPoolExecutor(max_workers=32)
self.finished_requests: set[str] = set()
def add_request(
self,
request: ReqMeta,
) -> torch.Tensor:
self.request_queue.put(request)
def get_and_clear_finished_requests(self) -> set[str]:
"""
Get and clear the requests that have been completed.
Returns:
A set of request IDs that have been completed.
"""
with self.done_task_lock:
finished_requests = self.finished_requests.copy()
self.finished_requests.clear()
return finished_requests
def set_finished_request(self, req_id):
with self.done_task_lock:
self.finished_requests.add(req_id)
def run(self):
"""Run the thread to handle KV cache transfer requests."""
self.m_store.set_device()
self.ready_event.set()
while True:
try:
request_data = self.request_queue.get()
if request_data is None:
logger.warning("Received a None request!")
self.request_queue.task_done()
continue
self._handle_request(request_data)
except Exception as e:
logger.error(f"Error in KVCacheTransferThread: {e}")
def _handle_request(self, req_meta: Any):
pass
def lookup(
self,
keys: list[str],
) -> int:
"""
Check the existence of the given KV cache blocks in the cache engine.
:param keys: the store keys of the blocks, in order.
:return: An int indicating how many prefix blocks are cached.
"""
try:
res = self.m_store.exists(keys) # type: ignore[assignment]
for index, value in enumerate(res): # type: ignore[arg-type]
if value != 1:
return index
# all tokens were found, return the maximal end
except Exception as e:
logger.error(f"Remote connection failed in contains: {e}")
return 0
return len(keys)
class KVCacheStoreSendingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
block_size: int, tp_rank: int, dcp_size: int, put_step: int,
kv_role: str, ready_event: threading.Event):
super().__init__(m_store,
token_database,
block_size,
tp_rank,
dcp_size,
ready_event,
name="KVCacheSendingThread")
self.put_step = put_step
self.kv_role = kv_role
self.stored_requests = defaultdict[str, int](int)
def add_stored_request(self, req_id: str):
with self.done_task_lock:
self.stored_requests[req_id] += 1
def delete_finished_stored_request(self, req_id: str):
with self.done_task_lock:
if req_id in self.stored_requests:
del self.stored_requests[req_id]
def _handle_request(self, req_meta: ReqMeta):
token_len = req_meta.token_len_chunk
block_ids = req_meta.block_ids
req_id = req_meta.req_id
current_event = req_meta.current_event
starts = []
ends = []
keys = []
for start, end, key in self.token_database.process_tokens(
token_len, req_meta.block_hashes):
starts.append(start)
ends.append(end)
keys.append(key.to_string())
if not self.dcp_size > 1:
starts = starts[self.tp_rank % self.put_step::self.put_step]
ends = ends[self.tp_rank % self.put_step::self.put_step]
keys = keys[self.tp_rank % self.put_step::self.put_step]
if not keys:
with self.done_task_lock:
self.stored_requests[req_id] -= 1
return
skip_block_num = self.lookup(keys)
if skip_block_num == len(keys):
with self.done_task_lock:
self.stored_requests[req_id] -= 1
return
starts = starts[skip_block_num:]
ends = ends[skip_block_num:]
keys = keys[skip_block_num:]
logger.info(
"Storing KV cache for %d out of %d blocks "
"(skip_block_num=%d) for request %s",
len(keys),
token_len // self.block_size,
skip_block_num,
req_id,
)
if keys:
"""
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
This issue will be fixed in CANN version 8.5.rc1.
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
to resolve this issue before the 8.5.RC1 release.
"""
addrs = []
sizes = []
for index, start in enumerate(starts):
addr, size, _ = self.token_database.prepare_value(
start, ends[index], block_ids)
addrs.append(addr)
sizes.append(size)
if self.kv_role == "kv_consumer":
keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp(
keys, addrs, sizes)
if current_event is not None:
current_event.synchronize()
self.m_store.put(keys, addrs, sizes)
with self.done_task_lock:
self.stored_requests[req_id] -= 1
self.request_queue.task_done()
class KVCacheStoreRecvingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
block_size: int, tp_rank: int, dcp_size: int,
ready_event: threading.Event):
super().__init__(m_store,
token_database,
block_size,
tp_rank,
dcp_size,
ready_event,
name="KVCacheStoreRecvingThread")
def _handle_request(self, req_meta: ReqMeta):
token_len = req_meta.load_spec.token_len # type: ignore[union-attr]
req_id = req_meta.req_id
mask_num = (
req_meta.load_spec.vllm_cached_tokens # type: ignore[union-attr]
// self.block_size * self.block_size)
addr_list = []
size_list = []
key_list = []
for start, end, key in self.token_database.process_tokens(
token_len, req_meta.block_hashes, mask_num):
addr, size, _ = self.token_database.prepare_value(
start, end, req_meta.block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
key_list_c = key_list[self.tp_rank %
len(key_list):] + key_list[:self.tp_rank %
len(key_list)]
addr_list_c = addr_list[self.tp_rank %
len(addr_list):] + addr_list[:self.tp_rank %
len(addr_list)]
size_list_c = size_list[self.tp_rank %
len(size_list):] + size_list[:self.tp_rank %
len(size_list)]
self.m_store.get(key_list_c, addr_list_c, size_list_c)
self.set_finished_request(req_id)
self.request_queue.task_done()
class KVCacheStoreLayerSendingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
block_size: int, tp_rank: int, dcp_size: int, put_step: int,
ready_event: threading.Event, num_layers: int):
super().__init__(m_store,
token_database,
block_size,
tp_rank,
dcp_size,
ready_event,
name="KVCacheStoreLayerSendingThread")
self.final_layer_id = num_layers - 1
self.put_step = put_step
def add_request( # type: ignore[override]
self, req_meta: ReqMeta) -> torch.Tensor:
self.request_queue.put(req_meta)
def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta):
starts = req_meta.starts
ends = req_meta.ends
keys = req_meta.keys
layer_id = req_meta.layer_id
current_event = req_meta.current_event
total_block = len(keys)
is_last_chunk = req_meta.is_last_chunk
if not self.dcp_size > 1:
starts = starts[self.tp_rank % self.put_step::self.put_step]
ends = ends[self.tp_rank % self.put_step::self.put_step]
keys = keys[self.tp_rank % self.put_step::self.put_step]
if not keys:
if is_last_chunk:
self.set_finished_request(req_meta.req_id)
return
key_list = []
for key in keys:
key_list.append(key.to_string())
skip_block_num = self.lookup(key_list)
if skip_block_num == len(key_list):
if is_last_chunk and layer_id == self.final_layer_id:
self.set_finished_request(req_meta.req_id)
return
starts = starts[skip_block_num:]
ends = ends[skip_block_num:]
key_list = key_list[skip_block_num:]
addr_list = []
size_list = []
for index, key in enumerate(key_list):
addr, size = self.token_database.prepare_value_layer(
starts[index], ends[index], req_meta.block_ids, layer_id)
addr_list.append(addr)
size_list.append(size)
if current_event is not None:
current_event.synchronize()
self.m_store.put(key_list, addr_list, size_list)
if layer_id == self.final_layer_id and is_last_chunk:
self.set_finished_request(req_meta.req_id)
self.request_queue.task_done()
logger.info(
"Storing KV cache for %d out of %d blocks "
"(skip_block_num=%d) for request %s",
len(keys),
total_block,
skip_block_num,
req_meta.req_id,
)
class KVCacheStoreLayerRecvingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
block_size: int, tp_rank: int, dcp_size: int,
ready_event: threading.Event, get_event: threading.Event):
super().__init__(m_store,
token_database,
block_size,
tp_rank,
dcp_size,
ready_event,
name="KVCacheStoreLayerRecvingThread")
self.get_event = get_event
def add_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
self.request_queue.put(req_meta)
def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta):
addr_list = []
size_list = []
key_list = []
for index, key in enumerate(req_meta.keys):
addr, size = self.token_database.prepare_value_layer(
req_meta.starts[index], req_meta.ends[index],
req_meta.block_ids, req_meta.layer_id)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
key_list_c = key_list[self.tp_rank %
len(key_list):] + key_list[:self.tp_rank %
len(key_list)]
addr_list_c = addr_list[self.tp_rank %
len(addr_list):] + addr_list[:self.tp_rank %
len(addr_list)]
size_list_c = size_list[self.tp_rank %
len(size_list):] + size_list[:self.tp_rank %
len(size_list)]
self.m_store.get(key_list_c, addr_list_c, size_list_c)
self.request_queue.task_done()
self.get_event.set()
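
A hedged sketch of how one of these threads is started and fed. The real wiring
lives in KVPoolWorker (pool_worker.py); backend, token_database and req_meta are
assumed to be built there and are placeholders here:

ready = threading.Event()
recv_thread = KVCacheStoreRecvingThread(
    m_store=backend,                # Backend instance built by KVPoolWorker
    token_database=token_database,  # ChunkedTokenDatabase built by KVPoolWorker
    block_size=128,
    tp_rank=0,
    dcp_size=1,
    ready_event=ready,
)
recv_thread.start()
ready.wait()                        # run() sets this once the NPU device is bound
recv_thread.add_request(req_meta)   # ReqMeta taken from the connector metadata
# Later, the worker polls for completed loads:
# finished_ids = recv_thread.get_and_clear_finished_requests()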


@@ -0,0 +1,272 @@
import math
import os
import pickle
from dataclasses import dataclass
from multiprocessing.shared_memory import SharedMemory
from typing import Any, Callable, Optional
import torch
import vllm.envs as envs
import zmq
from vllm.config import KVTransferConfig, VllmConfig
from vllm.logger import logger
from vllm.utils.network_utils import make_zmq_socket
from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.cpu_kv_cache_manager import \
CPUKVCacheManager
@dataclass
class MLAConfig:
nope_dim: int
rope_dim: int
def get_cpu_offload_connector(vllm_config: VllmConfig) -> KVTransferConfig:
if vllm_config.kv_transfer_config is not None:
kv_transfer_config = vllm_config.kv_transfer_config
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
return kv_transfer_config
elif kv_transfer_config.kv_connector == "MultiConnector":
ktcs = kv_transfer_config.kv_connector_extra_config.get(
"connectors")
for ktc in ktcs:
kv_transfer_config = KVTransferConfig(**ktc)
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
return kv_transfer_config
return None
class MetadataServer:
METADATA_SERVER_ADDRESS = f"ipc://{envs.VLLM_RPC_BASE_PATH}/metadata.ipc"
DEFAULT_CPU_SWAP_SPACE_GB = 800
class ZMQRPCClient:
def __init__(self, identity=f"worker-{os.getpid()}"):
logger.info(f"metadata client for worker {identity} started")
self.ctx = zmq.Context() # type: ignore
self.socket = make_zmq_socket(
self.ctx,
MetadataServer.METADATA_SERVER_ADDRESS,
zmq.DEALER, # type: ignore
bind=False,
identity=identity.encode(),
linger=0)
def call(self, func_name: str, *args, **kwargs) -> Any:
request = (func_name, args, kwargs)
self.socket.send(b"", zmq.SNDMORE) # type: ignore
self.socket.send(pickle.dumps(request))
_ = self.socket.recv()
response = pickle.loads(self.socket.recv())
result, error = response
if error:
logger.exception(f"call metadata sever error: {error}")
raise error
if func_name == "init_cpu_kv_caches":
(memory_dict, layer_size, layer_dtype, mla_config) = result
# shared_memory_dict is recorded in self to close
self.shared_memory_dict = memory_dict
result = {}
for key, shm in memory_dict.items():
tensor = torch.frombuffer(
shm.buf, dtype=layer_dtype).reshape(layer_size)
if mla_config is not None:
tensor = tensor.split(
[mla_config.nope_dim, mla_config.rope_dim], dim=-1)
result[key] = tensor
return result
def __del__(self):
# will be finalized by outer process
self.socket.close()
self.ctx.term()
if hasattr(self, 'shared_memory_dict'):
for shm in self.shared_memory_dict.values():
shm.close()
def __init__(self, vllm_config: VllmConfig):
self.world_size = vllm_config.parallel_config.world_size
self.pipeline_parallel_size = vllm_config.parallel_config.pipeline_parallel_size
kv_transfer_config = get_cpu_offload_connector(vllm_config)
assert kv_transfer_config is not None
available_memory_gb = kv_transfer_config.get_from_extra_config(
"cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB)
self.available_memory = available_memory_gb * 1024 * 1024 * 1024
logger.info(f"cpu swap space: {self.available_memory} bytes")
self.ctx = zmq.Context() # type: ignore
self.socket = make_zmq_socket(
self.ctx,
MetadataServer.METADATA_SERVER_ADDRESS,
zmq.ROUTER, # type: ignore
bind=True,
linger=0)
self.functions: dict[str, Callable] = {
"init_cpu_kv_caches": self.init_cpu_kv_caches,
"post_init": self.post_init,
"ready": self.ready,
}
self.shared_memory = {} # type: ignore
self.num_cpu_blocks = -1
@staticmethod
def _safe_create_shared_memory(name: str, size: int) -> SharedMemory:
try:
existing_shm = SharedMemory(name=name, create=False)
existing_shm.close()
existing_shm.unlink()
except FileNotFoundError:
pass
return SharedMemory(name=name, create=True, size=size)
def ready(self):
return True
def init_cpu_kv_caches(
self,
pp_rank: int,
tp_rank: int,
kv_cache_specs: dict[str, AttentionSpec],
mla_config: MLAConfig,
) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype,
MLAConfig]:
logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}")
# follow the assumption that each layer has the same spec
layer = next(iter(kv_cache_specs.values()))
assert all(spec.page_size_bytes == layer.page_size_bytes
for spec in kv_cache_specs.values())
use_mla = isinstance(layer, MLAAttentionSpec)
# mla shares the same kv cache among different tp
if use_mla:
tp_rank = 0
if (pp_rank, tp_rank) in self.shared_memory:
return self.shared_memory[(pp_rank, tp_rank)]
available_memory = self.available_memory
shared_memory_dict = {}
if use_mla:
available_memory //= self.pipeline_parallel_size
available_memory //= len(kv_cache_specs)
num_blocks = available_memory // layer.page_size_bytes
layer_size = (num_blocks, layer.block_size, layer.num_kv_heads,
layer.head_size) # type: ignore
else:
available_memory //= self.world_size
available_memory //= len(kv_cache_specs)
num_blocks = available_memory // layer.page_size_bytes
layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads,
layer.head_size) # type: ignore
nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype)
for layer_name in kv_cache_specs.keys():
# only SharedMemory objects in this form can be pickled and sent over ZeroMQ
shared_memory_dict[
layer_name] = MetadataServer._safe_create_shared_memory(
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
if use_mla:
assert mla_config is not None
assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
self.shared_memory[(pp_rank,
tp_rank)] = (shared_memory_dict, layer_size,
layer.dtype, mla_config)
else:
self.shared_memory[(pp_rank,
tp_rank)] = (shared_memory_dict, layer_size,
layer.dtype, None)
if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks:
self.num_cpu_blocks = num_blocks
self.layer = layer
return self.shared_memory[(pp_rank, tp_rank)]
def post_init(self):
# different processors in data parallel may call multiple times
if hasattr(self, 'cpu_block_manager'):
return
# do shared_memory() at least once
logger.info(f"assign cpu num blocks: {self.num_cpu_blocks}")
assert self.num_cpu_blocks >= 0
self.cpu_block_manager = CPUKVCacheManager(self.layer,
self.num_cpu_blocks)
self.functions.update({
"get_matched_num_and_touch":
self.cpu_block_manager.get_matched_num_and_touch,
"allocate_slots":
self.cpu_block_manager.allocate_slots,
"record_request_cache_and_free_slots":
self.cpu_block_manager.record_request_cache_and_free_slots,
"cache_and_free_slots":
self.cpu_block_manager.cache_and_free_slots,
})
def serve_step(self):
client_id = self.socket.recv()
_ = self.socket.recv()
raw_msg = self.socket.recv()
try:
func_name, args, kwargs = pickle.loads(raw_msg)
except Exception as e:
response = (None, Exception(f"Invalid request: {str(e)}"))
else:
if func_name in self.functions:
try:
result = self.functions[func_name](*args, **kwargs)
response = (result, None) # type: ignore
except Exception as e:
logger.exception(f"metadata execute error: {e}")
response = (None, e) # type: ignore
else:
response = (None, NameError(f"Function {func_name} not found"))
self.socket.send(client_id, zmq.SNDMORE) # type: ignore
self.socket.send(b"", zmq.SNDMORE) # type: ignore
self.socket.send(pickle.dumps(response))
def shutdown(self):
self.socket.close()
self.ctx.term()
socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace(
"ipc://", "")
if os.path.exists(socket_path):
os.remove(socket_path)
for cached in self.shared_memory.values():
for shm in cached[0].values():
shm.close()
shm.unlink()
class MetadataServerProc:
@staticmethod
def run_metadata_server(vllm_config: VllmConfig):
if (not vllm_config.cache_config.enable_prefix_caching
or get_cpu_offload_connector(vllm_config) is None):
return
shutdown_requested = False
def _signal_handler(signum, frame):
nonlocal shutdown_requested
if not shutdown_requested:
shutdown_requested = True
raise SystemExit()
# Either SIGTERM or SIGINT will terminate the worker
# signal.signal(signal.SIGTERM, _signal_handler)
# signal.signal(signal.SIGINT, _signal_handler)
metadata_server: Optional[MetadataServer] = None
try:
metadata_server = MetadataServer(vllm_config)
logger.info("Metadata server started.")
while True:
metadata_server.serve_step()
except SystemExit:
logger.info("Metadata server exiting.")
raise
except Exception as e:
logger.exception(f"Metadata server error: {e}.")
raise e
finally:
if metadata_server is not None:
metadata_server.shutdown()
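
A hedged sketch of the worker-side RPC flow against this server; pp_rank, tp_rank,
kv_cache_specs and mla_config are the worker's own values, and the call order is an
assumption:

client = MetadataServer.ZMQRPCClient()
assert client.call("ready")                       # server is up and answering
cpu_kv_caches = client.call("init_cpu_kv_caches", pp_rank, tp_rank,
                            kv_cache_specs, mla_config)  # per-layer CPU tensors (split for MLA)
client.call("post_init")                          # builds the CPUKVCacheManager once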


@@ -0,0 +1,331 @@
from typing import Any, Optional
import vllm.envs as envs
import zmq
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
KVConnectorMetadata
from vllm.logger import logger
from vllm.utils.network_utils import make_zmq_socket
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import Request
from vllm.v1.serial_utils import MsgpackEncoder
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
AscendConnectorMetadata, LoadSpec, ReqMeta, RequestTracker)
class KVPoolScheduler:
def __init__(self, vllm_config: "VllmConfig", use_layerwise):
self.use_layerwise = use_layerwise
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_load", False)
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_put", False)
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
self.client = LookupKeyClient(vllm_config)
# request_id -> (vllm cached tokens, kvpool cached tokens)
self.load_specs: dict[str, LoadSpec] = {}
self.pcp_size = getattr(vllm_config.parallel_config,
"prefill_context_parallel_size", 1)
self.dcp_size = getattr(vllm_config.parallel_config,
"decode_context_parallel_size", 1)
self._block_size = vllm_config.cache_config.block_size
if self.pcp_size > 1:
self._block_size *= self.pcp_size
if self.dcp_size > 1:
self._block_size *= self.dcp_size
# request_id -> full_token_ids
self._request_trackers: dict[str, RequestTracker] = {}
# Whether to discard partial chunks
self._discard_partial_chunks = (
vllm_config.kv_transfer_config.get_from_extra_config(
"discard_partial_chunks", True))
self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {}
self._unfinished_request_ids: set[str] = set()
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
"""
Check for external KV cache hit.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
if self.kv_role == "kv_consumer" and not self.consumer_is_to_load:
return 0, False
if self._discard_partial_chunks:
token_len = len(request.prompt_token_ids
) // self._block_size * self._block_size
else:
token_len = len(request.prompt_token_ids)
num_external_hit_tokens = self.client.lookup(token_len,
request.block_hashes)
if num_external_hit_tokens == request.num_tokens:
num_external_hit_tokens -= 1
if num_external_hit_tokens < num_computed_tokens:
need_to_allocate = 0
else:
need_to_allocate = num_external_hit_tokens - num_computed_tokens
logger.info(
"Reqid: %s, Total tokens %d, kvpool hit tokens: %d, need to load: %d",
request.request_id,
request.num_tokens,
num_external_hit_tokens,
need_to_allocate,
)
if need_to_allocate <= 0:
return 0, False
self.load_specs[request.request_id] = LoadSpec(
vllm_cached_tokens=num_computed_tokens,
kvpool_cached_tokens=num_external_hit_tokens,
can_load=False,
)
return need_to_allocate, self.load_async and not self.use_layerwise
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.
Record the request's allocated blocks and, if the KV cache manager
allocated space for the externally cached tokens, mark the pending
load as ready (can_load = True).
"""
local_block_ids = []
if num_external_tokens > 0:
local_block_ids = blocks.get_block_ids()[0]
self._unfinished_requests[request.request_id] = (request,
local_block_ids)
self._unfinished_request_ids.add(request.request_id)
if request.request_id not in self.load_specs:
# No KV tokens from external KV cache, return
return
if num_external_tokens == 0:
# No need to load anything
self.load_specs[request.request_id].can_load = False
return
assert (
num_external_tokens > 0 and num_external_tokens
== self.load_specs[request.request_id].kvpool_cached_tokens -
self.load_specs[request.request_id].vllm_cached_tokens
), (f"Mismatch in number of tokens: {num_external_tokens} vs "
f"{self.load_specs[request.request_id].kvpool_cached_tokens} - "
f"{self.load_specs[request.request_id].vllm_cached_tokens}"
f" for request {request.request_id}")
self.load_specs[request.request_id].can_load = True
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
"""Attach the connector metadata to the request object.
This function should NOT modify other fields in the scheduler_output
except the `kv_connector_metadata` field.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
force_skip_save = (self.kv_role == "kv_consumer"
and not self.consumer_is_to_put)
for finished_req_id in scheduler_output.finished_req_ids:
self._request_trackers.pop(finished_req_id, None)
self._unfinished_requests.pop(finished_req_id, None)
self._unfinished_request_ids.discard(finished_req_id)
meta = AscendConnectorMetadata(self._unfinished_request_ids)
for request in scheduler_output.scheduled_new_reqs:
# Right now, we only load KV for new requests
load_spec = self.load_specs.pop(request.req_id, None)
num_tokens_to_compute = (
request.num_computed_tokens +
scheduler_output.num_scheduled_tokens[request.req_id])
request_tracker = RequestTracker.from_new_request(
request, num_tokens_to_compute)
self._request_trackers[request.req_id] = request_tracker
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
self._block_size * self._block_size)
if self._discard_partial_chunks else len(
request.prompt_token_ids))
request_tuple = self._unfinished_requests.get(request.req_id)
request_real = request_tuple[0] # type: ignore[index]
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=load_spec,
skip_save=force_skip_save,
block_hashes=request_real.block_hashes,
is_last_chunk=request_tracker.token_len
>= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
cached_reqs = scheduler_output.scheduled_cached_reqs
if not force_skip_save:
for i, req_id in enumerate(cached_reqs.req_ids):
request_tracker = self._request_trackers[req_id]
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
req_tuple = self._unfinished_requests.get(req_id)
if req_tuple:
request = req_tuple[0]
num_current_tokens = request_tracker.token_len
new_token_ids = request.all_token_ids[
num_current_tokens:num_current_tokens + num_new_tokens]
request_tracker.token_len += len(new_token_ids)
else:
raise ValueError(
f"Request {req_id} is not in _unfinished_requests, "
f"but it is scheduled to be cached")
new_block_ids = cached_reqs.new_block_ids[i]
if not new_block_ids:
continue
request_tracker.update(new_block_ids)
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
self._block_size * self._block_size)
if self._discard_partial_chunks else
len(request.prompt_token_ids))
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=None,
skip_save=force_skip_save,
block_hashes=request.block_hashes,
is_last_chunk=request_tracker.token_len
>= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
request_ids = [
req.req_id for req in scheduler_output.scheduled_new_reqs
]
for request_id, (request,
block_ids) in self._unfinished_requests.items():
if request_id not in request_ids and request_id not in cached_reqs.req_ids:
load_spec = self.load_specs.pop(request_id, None)
if not load_spec:
continue
num_tokens_to_compute = load_spec.kvpool_cached_tokens
if (num_tokens_to_compute % self._block_size
!= 0) and (num_tokens_to_compute
== len(request.prompt_token_ids) - 1):
num_tokens_to_compute = num_tokens_to_compute + 1
request_tracker = RequestTracker(
req_id=request_id,
token_len=num_tokens_to_compute,
allocated_block_ids=block_ids,
num_saved_tokens=0,
)
self._request_trackers[request_id] = request_tracker
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=load_spec,
skip_save=None,
block_hashes=request.block_hashes,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
return meta
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Once a request is finished, determine whether request blocks
should be freed now or will be sent asynchronously and freed later.
"""
if self.kv_role == "kv_consumer" and not self.consumer_is_to_put:
return False, None
tracker = self._request_trackers.get(request.request_id)
if tracker is not None and tracker.num_saved_tokens <= 0:
return False, None
delay_free_blocks = len(block_ids) > 0
if delay_free_blocks:
logger.info("Delaying free of %d blocks for request %s",
len(block_ids), request.request_id)
return delay_free_blocks, None
class LookupKeyClient:
def __init__(self, vllm_config: "VllmConfig"):
self.encoder = MsgpackEncoder()
self.ctx = zmq.Context() # type: ignore[attr-defined]
socket_path = get_zmq_rpc_path_lookup(vllm_config)
self.socket = make_zmq_socket(
self.ctx,
socket_path,
zmq.REQ, # type: ignore[attr-defined]
bind=False,
)
def lookup(self, token_len: int, block_hashes: list[BlockHash]) -> int:
hash_strs = [h.hex() for h in block_hashes]
hash_frames = self.encoder.encode(hash_strs)
token_len_bytes = token_len.to_bytes(4, byteorder="big")
all_frames = [token_len_bytes] + list(hash_frames)
self.socket.send_multipart(all_frames, copy=False)
resp = self.socket.recv()
result = int.from_bytes(resp, "big")
return result
def close(self):
self.socket.close(linger=0)
def get_zmq_rpc_path_lookup(vllm_config: "VllmConfig") -> str:
dp_rank = vllm_config.parallel_config.data_parallel_rank
base_url = envs.VLLM_RPC_BASE_PATH
# Default to 0 if not configured
rpc_port = 0
if vllm_config is not None:
extra_config = vllm_config.kv_transfer_config.kv_connector_extra_config
if "lookup_rpc_port" in extra_config:
rpc_port = extra_config["lookup_rpc_port"]
elif "mooncake_rpc_port" in extra_config:
rpc_port = extra_config["mooncake_rpc_port"]
logger.warning(
"It is recommended to use the lookup_rpc_port, as the mooncake_rpc_port will be removed in the future."
)
logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port)
return f"ipc://{base_url}/lookup_rpc_port_{rpc_port}_dp_rank{dp_rank}"


@@ -0,0 +1,621 @@
import math
import threading
from typing import Dict, Generator, Optional, Type
import torch
from vllm.config import VllmConfig
from vllm.distributed import (get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size,
get_pcp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.logger import logger
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
Backend
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.memcache_backend import \
MemcacheBackend
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.mooncake_backend import \
MooncakeBackend
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
AscendConnectorMetadata, ChunkedTokenDatabase, KeyMetadata,
LasyerMultiBlockReqMeta, ReqMeta)
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
backend_map: Dict[str, Type[Backend]] = {
"mooncake": MooncakeBackend,
"memcache": MemcacheBackend,
}
class KVPoolWorker:
# The main class for the cache engine.
def __init__(
self,
vllm_config: VllmConfig,
use_layerwize: bool,
):
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
self.dp_rank = parallel_config.data_parallel_rank
self.use_mla = False
if (hasattr(model_config, "use_mla")
and isinstance(model_config.use_mla, bool)
and model_config.use_mla):
self.use_mla = True
self.use_layerwise = use_layerwize
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.pp_size = parallel_config.pipeline_parallel_size
self.pp_rank = (parallel_config.rank // self.tp_size) % self.pp_size
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_put", False)
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"backend", "mooncake")
self.block_size = vllm_config.cache_config.block_size
if self.pcp_size > 1:
self.block_size *= self.pcp_size
if self.dcp_size > 1:
self.block_size *= self.dcp_size
self.current_layer = 0
self.num_layers = model_config.get_num_layers(parallel_config)
if self.use_mla:
self.num_kv_head = 1
else:
self.num_kv_head = model_config.get_total_num_kv_heads()
if self.num_kv_head < self.tp_size:
self.put_step = self.tp_size // self.num_kv_head
self.head_or_tp_rank = self.tp_rank // self.put_step
else:
self.head_or_tp_rank = self.tp_rank
self.put_step = 1
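# e.g. tp_size=16, num_kv_head=8 -> put_step=2 and head_or_tp_rank = tp_rank // 2,
# so tp ranks 0-1 map to kv head group 0, ranks 2-3 to group 1, and so on.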
self.metadata = KeyMetadata(
model_config.model.rstrip('/').split('/')[-1],
self.head_or_tp_rank,
self.pcp_rank,
self.dcp_rank,
self.pp_rank,
)
partitions = None
if self.kv_role == "kv_consumer" and self.consumer_is_to_put:
num_hidden_layers = model_config.hf_text_config.num_hidden_layers
partition_list_str = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"prefill_pp_layer_partition", None)
prefill_pp_size = int(
vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"prefill_pp_size", 1))
if partition_list_str is not None:
try:
partitions = [
int(layer) for layer in partition_list_str.split(",")
]
except ValueError as err:
raise ValueError("Invalid partition string: {}".format(
partition_list_str)) from err
if len(partitions) != prefill_pp_size:
raise ValueError(
f"{len(partitions)=} does not match {prefill_pp_size=}."
)
if sum(partitions) != num_hidden_layers:
raise ValueError(
f"{sum(partitions)=} does not match {num_hidden_layers=}."
)
else:
layers_per_partition = num_hidden_layers // prefill_pp_size
partitions = [
layers_per_partition for _ in range(prefill_pp_size)
]
if remaining_layers := num_hidden_layers % prefill_pp_size:
for i in range(2, remaining_layers + 2):
partitions[-i] += 1
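# e.g. num_hidden_layers=30, prefill_pp_size=4 -> layers_per_partition=7,
# remaining_layers=2, so partitions becomes [7, 8, 8, 7] (the extra layers are
# added from the second-to-last partition backwards).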
self.token_database = ChunkedTokenDatabase(self.metadata,
self.block_size,
self.use_mla, partitions)
real_backend = backend_map.get(self.backend.lower())
# TODO: this mooncake-specific override is to be removed later
if self.backend == "mooncake":
self.head_or_tp_rank = self.tp_rank
self.put_step = 1
self.m_store = real_backend( # type: ignore[misc]
parallel_config)
self.kv_send_thread: Optional[KVTransferThread] = None
self.kv_recv_thread: Optional[KVTransferThread] = None
self.finished_store_req: set[str] = set()
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
first_kv_cache = first_kv_cache_tuple[0]
# TODO(tms): Find a more robust way to detect and handle MLA
if self.use_mla:
# MLA case: [num_block, block_size, 1, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe)
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks, block_shape_norm, block_shape_pe)
else:
# [num_block, block_size, num_head, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
kv_elem_size = first_kv_cache.element_size()
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
self.block_len = [kv_elem_size * math.prod(block_shape)]
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape)
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
self.use_mla, first_kv_cache.shape)
self.kv_caches = kv_caches
self.kv_caches_base_addr = []
ptrs = []
lengths = []
for cache_or_caches in kv_caches.values():
# Normalize to always be a list of caches
if self.use_mla:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
self.kv_caches_base_addr.append(base_addr)
region_len = self.num_blocks * self.block_len[i % 2]
ptrs.append(base_addr)
lengths.append(region_len)
else:
cache_list = [cache_or_caches
] if self.use_mla else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
self.kv_caches_base_addr.append(base_addr)
region_len = self.num_blocks * self.block_len[0]
ptrs.append(base_addr)
lengths.append(region_len)
self.m_store.register_buffer(ptrs, lengths)
self.token_database.set_kv_caches_base_addr(self.kv_caches_base_addr)
self.token_database.set_block_len(self.block_len)
if self.use_layerwise:
self.get_event = threading.Event()
if self.kv_role in ['kv_producer', 'kv_both']:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreLayerSendingThread(
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, self.put_step,
ready_event_sending, self.num_layers)
self.kv_send_thread.start()
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, ready_event, self.get_event)
self.kv_recv_thread.start()
ready_event.wait()
else:
if self.kv_role in ['kv_producer', 'kv_both'
] or self.consumer_is_to_put:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreSendingThread(
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, self.put_step, self.kv_role,
ready_event_sending)
self.kv_send_thread.start()
if self.load_async:
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreRecvingThread(
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, ready_event)
self.kv_recv_thread.start()
ready_event.wait()
def start_load_kv(self, metadata: AscendConnectorMetadata):
self.current_layer = 0
self.layerwise_retrievers = []
for request in metadata.requests:
load_spec = request.load_spec
if load_spec is None or not load_spec.can_load:  # nothing to load for this request
continue
token_len = request.token_len_chunk
if (load_spec.kvpool_cached_tokens % self.block_size
!= 0) and (load_spec.kvpool_cached_tokens
== token_len - 1):
token_len = request.load_spec.kvpool_cached_tokens + 1
else:
token_len = request.load_spec.kvpool_cached_tokens
request.load_spec.token_len = token_len
if self.use_layerwise:
layerwise_retriever = self.retrieve_layer(request)
next(layerwise_retriever) # first layer load
self.layerwise_retrievers.append(layerwise_retriever)
else:
if self.load_async:
self.kv_recv_thread.add_request( # type: ignore[union-attr]
request, )
else:
addr_list = []
size_list = []
key_list = []
mask_num = (request.load_spec.vllm_cached_tokens //
self.block_size * self.block_size)
for start, end, key in self.token_database.process_tokens(
token_len, request.block_hashes, mask_num):
addr, size, _ = self.token_database.prepare_value(
start, end, request.block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
key_list_c = key_list[self.tp_rank % len(
key_list):] + key_list[:self.tp_rank % len(key_list)]
addr_list_c = addr_list[self.tp_rank %
len(addr_list
):] + addr_list[:self.tp_rank %
len(addr_list)]
size_list_c = size_list[self.tp_rank %
len(size_list
):] + size_list[:self.tp_rank %
len(size_list)]
self.m_store.get(key_list_c, addr_list_c, size_list_c)
def wait_for_layer_load(self) -> None:
for layerwise_retriever in self.layerwise_retrievers:
ret_token_mask = next(layerwise_retriever)
if self.current_layer == self.num_layers - 1:
assert ret_token_mask is not None
num_retrieved_tokens = ret_token_mask.sum().item()
logger.debug(f"Retrieved {num_retrieved_tokens} tokens")
def save_kv_layer(self,
connector_metadata: AscendConnectorMetadata) -> None:
if self.current_layer == 0:
self.layerwise_storers = []
current_event = None
for request in connector_metadata.requests:
can_save = request.can_save
if can_save is None or not can_save:
continue
current_event = torch.npu.Event()
current_event.record()
break
for request in connector_metadata.requests:
can_save = request.can_save
if can_save is None or not can_save:
continue
layerwise_storer = self.store_layer(request, current_event)
self.layerwise_storers.append(layerwise_storer)
for layerwise_storer in self.layerwise_storers:
try:
next(layerwise_storer)
except Exception:
raise
self.current_layer = self.current_layer + 1
def wait_for_save(self, connector_metadata: AscendConnectorMetadata):
current_event = None
for request in connector_metadata.requests:
can_save = request.can_save
if can_save is None or not can_save:
continue
current_event = torch.npu.Event()
current_event.record()
break
for request in connector_metadata.requests:
can_save = request.can_save
if can_save is None or not can_save:
continue
request.current_event = current_event
self.kv_send_thread.add_stored_request( # type: ignore[union-attr]
request.req_id)
self.kv_send_thread.add_request( # type: ignore[union-attr]
request, )
def retrieve_layer(
self,
request: ReqMeta,
) -> Generator[Optional[torch.Tensor], None, None]:
"""
Retrieve the KV cache in a layerwise manner.
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
have the same length as tokens. The mask should ALWAYS be like
FFFFFTTTTTTT, where True means the tokens need to be matched.
:param **kwargs: The additional arguments for the KV transfer which
will be passed into the npu_transfer.
return: A generator that yields Optional[torch.Tensor]. The tensor will
be the boolean mask indicating which tokens are retrieved and will
only be returned in the last iteration.
"""
token_len = request.token_len_chunk
mask_num = (
request.load_spec.vllm_cached_tokens # type: ignore[union-attr]
// self.block_size * self.block_size)
num_required_tokens = token_len - mask_num
ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu")
starts = []
ends = []
keys = []
first_flag = True
for start, end, key in self.token_database.process_tokens(
token_len, request.block_hashes, mask_num):
keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start)
ends.append(end)
keys.append(keys_multi_layer)
ret_mask[start:end] = True
if keys:
# Transpose the keys into layer major format
keys = [list(row) for row in zip(*keys)] # [num_layer,block_num]
for layer_id, keys_multi_chunk in enumerate(keys):
if not first_flag:
is_finish = self.get_event.wait(timeout=3)  # wait for the previous layer's get to finish
if not is_finish:
logger.info("Layerwise get failed")
self.get_event.clear()
req_meta = LasyerMultiBlockReqMeta(request.req_id,
keys_multi_chunk, starts,
ends, request.block_ids,
layer_id)
self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type]
first_flag = False
yield None
else:
# If no cache is found, we still need to yield to avoid
# `StopIteration`
for layer_id in range(self.num_layers):
yield None
retrieved_tokens = torch.sum(ret_mask)
logger.debug(f"Retrieved {retrieved_tokens} "
f"out of {num_required_tokens} "
f"out of total {token_len} tokens")
yield ret_mask
def store_layer(
self,
request: ReqMeta,
current_event: Optional[torch.npu.Event],
) -> Generator[None, None, None]:
"""
Store the KV cache in a layerwise manner.
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
have the same length as tokens. The mask should ALWAYS be like
FFFFFTTTTTTT, where True means the tokens need to be matched.
:param **kwargs: The additional arguments for the storage backend which
will be passed into the gpu_connector.
return: A generator that yields None. In the first iteration, the
generator allocates the memory objects for all layers and moves
the KV cache of the first layer from GPU to CPU. In the next
iterations, it moves the KV cache of layer i from GPU to the memory
objects (on CPU) and puts the memory objects of layer i-1 to the
storage backends. In the last iteration, it puts the memory objects
of the last layer to the storage backends.
"""
starts = []
ends = []
keys = []
for start, end, key in self.token_database.process_tokens(
request.token_len_chunk, request.block_hashes):
keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start)
ends.append(end)
keys.append(keys_multi_layer) #[block_num,layer_num]
if keys:
keys = [list(row) for row in zip(*keys)] #[layer_num,block_num]
for layer_id, keys_multi_chunk in enumerate(keys):
req_meta = LasyerMultiBlockReqMeta(request.req_id,
keys_multi_chunk, starts,
ends, request.block_ids,
layer_id,
request.is_last_chunk,
current_event)
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type]
yield
else:
for layer_id in range(self.num_layers):
yield
def get_finished(self,
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
done_sending = (
self.get_and_clear_finished_requests(
finished_req_ids # type: ignore[union-attr]
) if self.kv_role in ['kv_producer', 'kv_both']
or self.consumer_is_to_put else set())
done_recving = (
self.kv_recv_thread.
get_and_clear_finished_requests( # type: ignore[union-attr]
) if self.load_async else set())
logger.debug(
"Number of completed KV cache send requests: %d, receive "
"requests: %d, tp_rank:%d", len(done_sending), len(done_recving),
self.tp_rank)
return done_sending, done_recving
def get_and_clear_finished_requests(self, finished_req_ids) -> set[str]:
finished_sending = set()
for req_id in self.kv_send_thread.stored_requests.copy( # type: ignore[union-attr]
):
if self.kv_send_thread.stored_requests[ # type: ignore[union-attr]
req_id] == 0 and req_id in self.finished_store_req:
self.finished_store_req.remove(req_id)
finished_sending.add(req_id)
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
req_id)
for req_id in finished_req_ids:
req_remain_jobs = self.kv_send_thread.stored_requests.get( # type: ignore[union-attr]
req_id)
if req_remain_jobs == 0:
finished_sending.add(req_id)
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
req_id)
elif req_remain_jobs is not None:
self.finished_store_req.add(req_id)
return finished_sending
def lookup(
self,
token_len: int,
block_hashes: list[BlockHash],
use_layerwise: bool,
) -> int:
"""
Checks the existence of the KV cache for the tokens in the cache engine.
:param token_len: the number of input tokens
:param block_hashes: the block hashes of the input tokens
:param use_layerwise: whether keys are stored per layer
:return: An int indicating how many prefix tokens are cached.
"""
start = 0
end = 0
keys = []
try:
starts = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes):
if use_layerwise:
keys_multi_layer = key.split_layers(self.num_layers)
for item in keys_multi_layer:
keys.append(item.to_string())
else:
keys.append(key.to_string())
starts.append(start)
res = self.m_store.exists(keys) # type: ignore[assignment]
if use_layerwise:
res = self.check_all_layers_exists(res, self.num_layers)
for index, value in enumerate(res): # type: ignore[arg-type]
if value != 1:
return starts[index]
# all tokens were found, return the maximal end
except Exception as e:
logger.error(f"Remote connection failed in contains: {e}")
return start
return end
def lookup_scheduler(
self,
token_len: int,
block_hashes: list[BlockHash],
use_layerwise: bool,
) -> int:
"""
Checks the existence of the KV cache for the tokens in the cache engine.
Unlike lookup(), this checks the keys across all head/tp and pp ranks.
:param token_len: the number of input tokens
:param block_hashes: the block hashes of the input tokens
:param use_layerwise: whether keys are stored per layer
:return: An int indicating how many prefix tokens are cached.
"""
start = 0
end = 0
keys = []
try:
starts = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes):
if use_layerwise:
keys_multi_layer = key.split_layers(self.num_layers)
for item in keys_multi_layer:
keys.append(item.to_string())
else:
keys.append(key.to_string())
starts.append(start)
multi_tp_keys = keys[:]
for i in range(1, min(self.tp_size, self.num_kv_head)):
for item in keys:
new_str = item.replace( # type: ignore[attr-defined]
"@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1)
multi_tp_keys.append(new_str)
for i in range(1, self.pp_size):
for item in keys:
new_str = item.replace( # type: ignore[attr-defined]
"@pp_rank:0", f"@pp_rank:{i}", 1)
multi_tp_keys.append(new_str)
res = self.m_store.exists(
multi_tp_keys) # type: ignore[assignment]
num_block = len(keys)
if use_layerwise:
res = self.check_all_layers_exists(res, self.num_layers)
num_block = len(keys) // self.num_layers
multi_tp_values = [
res[i * num_block:(i + 1) * num_block] # type: ignore[index]
for i in range(
min(self.tp_size, self.num_kv_head) * self.pp_size)
]
index = self.find_min_first_non_one_index(multi_tp_values)
if index != -1:
return starts[index]
# all tokens were found, return the maximal end
except Exception as e:
logger.error(f"Remote connection failed in contains: {e}")
return start
return end
def check_all_layers_exists(self, res: list[int],
num_layers: int) -> list[int]:
total_chunks = len(res) // num_layers
result = []
for chunk_idx in range(total_chunks):
start = chunk_idx * num_layers
end = start + num_layers
chunk = res[start:end]
result.append(1 if all(x == 1 for x in chunk) else 0)
return result
def find_min_first_non_one_index(self, arr):
try:
return min(idx for row in arr for idx, val in enumerate(row)
if val != 1)
except ValueError:
return -1
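# Illustrative self-check with made-up values (2 layers, 3 blocks): how the flattened
# per-(block, layer) existence flags returned by the backend collapse into per-block hits,
# mirroring check_all_layers_exists() and find_min_first_non_one_index() above.
if __name__ == "__main__":
    flags = [1, 1, 1, 0, 1, 1]  # block-major: block0(layer0, layer1), block1(...), block2(...)
    per_block = [
        1 if all(x == 1 for x in flags[i * 2:(i + 1) * 2]) else 0 for i in range(3)
    ]
    assert per_block == [1, 0, 1]  # block 1 misses a layer, so the cached prefix ends there
    assert min(idx for idx, val in enumerate(per_block) if val != 1) == 1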

View File

@@ -0,0 +1,203 @@
import time
from collections import defaultdict
from typing import Optional
from vllm.logger import logger
from vllm.utils.hashing import sha256
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
PrefixCachingMetrics)
from vllm.v1.core.single_type_kv_cache_manager import \
get_manager_for_kv_cache_spec
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
class CPUCacheStats:
def __init__(self, enable_prefix_caching: bool, log_stats: bool = False):
self.enable_prefix_caching = enable_prefix_caching
self.log_stats = log_stats
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
self.cpu_prefix_cache_metrics = PrefixCachingMetrics()
self.time_sec = int(time.time())
def log(self):
current_time_sec = int(time.time())
# Log the prefix cache hit rate every 10 seconds.
if current_time_sec - self.time_sec >= 10:
self.time_sec = current_time_sec
logger.info("CPU Prefix cache hit rate: %.1f%%",
self.cpu_prefix_cache_metrics.hit_rate * 100)
def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
"""Get (and reset) the prefix cache stats.
Returns:
The current prefix caching stats, or None if logging is disabled.
"""
if not self.log_stats:
return None
stats = self.prefix_cache_stats
self.prefix_cache_stats = PrefixCacheStats()
return stats
def update(self, num_tokens, num_computed_tokens):
# Note the function is called by scheduler
if self.log_stats and self.enable_prefix_caching:
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.requests += 1
self.prefix_cache_stats.queries += num_tokens
self.prefix_cache_stats.hits += num_computed_tokens
def set_cache_stats(self, num_tokens, num_computed_tokens):
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.hits = num_computed_tokens
self.prefix_cache_stats.queries = num_tokens
self.prefix_cache_stats.requests = 1
class CPUKVCacheManager:
def __init__(
self,
kv_cache_spec: KVCacheSpec,
num_cpu_blocks: int,
caching_hash_algo: str = "builtin",
use_eagle: bool = False,
enable_kv_cache_events: bool = False,
) -> None:
self.block_size = kv_cache_spec.block_size
self.num_cpu_blocks = num_cpu_blocks
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
self.use_eagle = use_eagle
self.block_pool = BlockPool(self.num_cpu_blocks, True,
enable_kv_cache_events)
self.single_type_manager = get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=0,
)
# Record kv block hashes, avoid redundant computation.
self.req_to_block_hashes: defaultdict[
str, list[BlockHash]] = defaultdict(list)
# Record blocks touched in get_matched_num_and_touch().
self.req_to_computed_blocks: defaultdict[
str, list[KVCacheBlock]] = defaultdict(list)
# Record the request that failed to allocate.
self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool)
self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int)
self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True,
log_stats=True)
# Record request that will be free after finish sending
self.req_to_free: defaultdict[str, Request] = defaultdict(Request)
def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]:
# When the request requires prompt logprobs, we skip prefix caching.
if (request.sampling_params.prompt_logprobs is not None):
return 0, False
request_id = request.request_id
# The block hashes for the request may already be computed
# if the scheduler has tried to schedule the request before.
block_hashes = self.req_to_block_hashes[request_id]
if not block_hashes:
block_hashes = request.block_hashes
self.req_to_block_hashes[request_id] = block_hashes
max_cache_hit_length = request.num_tokens - 1
computed_blocks = self.single_type_manager.find_longest_cache_hit(
block_hashes=block_hashes,
max_length=max_cache_hit_length,
kv_cache_group_ids=[0],
block_pool=self.block_pool,
kv_cache_spec=self.single_type_manager.kv_cache_spec,
use_eagle=self.use_eagle,
)
num_computed_tokens = len(computed_blocks[0]) * self.block_size
self.req_to_computed_blocks[request_id] = computed_blocks[0]
# We should touch these blocks in the concurrent scenarios.
self.block_pool.touch(computed_blocks)
# cpu prefix cache stats: set and log
assert self.cpu_cache_stats is not None and self.cpu_cache_stats.prefix_cache_stats is not None
self.cpu_cache_stats.set_cache_stats(request.num_tokens,
num_computed_tokens)
self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(
self.cpu_cache_stats.prefix_cache_stats)
self.cpu_cache_stats.log()
return num_computed_tokens, False
def _release_ahead_touch(self, request_id: str):
computed_blocks = self.req_to_computed_blocks[request_id]
if computed_blocks:
self.single_type_manager.block_pool.free_blocks(
reversed(computed_blocks))
self.req_to_computed_blocks.pop(request_id, None)
def allocate_slots(self, req_to_num_tokens: dict[str, int],
unallocated_req_ids: set[str]) -> dict[str, list[int]]:
for request_id in unallocated_req_ids:
self._free_slots(request_id)
req_to_new_blocks = {}
for request_id, num_tokens in req_to_num_tokens.items():
if self.req_failed_to_allocate[request_id]:
continue
new_computed_blocks = self.req_to_computed_blocks[request_id]
num_blocks_to_allocate = (
self.single_type_manager.get_num_blocks_to_allocate(
request_id=request_id,
num_tokens=num_tokens,
new_computed_blocks=new_computed_blocks,
))
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
self._release_ahead_touch(request_id)
self.req_failed_to_allocate[request_id] = True
continue
# Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated.
self.single_type_manager.save_new_computed_blocks(
request_id, new_computed_blocks)
# Allocate new blocks but do not cache now.
new_blocks = self.single_type_manager.allocate_new_blocks(
request_id, num_tokens)
self.req_to_num_tokens[request_id] = num_tokens
# No need to release the touched ref_cnt here: these blocks are now officially in use by the request.
self.req_to_computed_blocks.pop(request_id, None)
req_to_new_blocks[request_id] = [
block.block_id for block in new_computed_blocks + new_blocks
]
return req_to_new_blocks
def record_request_cache_and_free_slots(self, request: Request):
logger.debug(
f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager"
)
self.req_to_free[request.request_id] = request
def cache_and_free_slots(self, request_id: str):
logger.debug(
f"Cache and free slots for request {request_id} in cpu_kv_cache_manager"
)
if request_id not in self.req_to_free:
logger.error(
    f"request {request_id} not in req_to_free, this is likely a bug!")
return
request = self.req_to_free[request_id]
if not self.req_failed_to_allocate[request_id]:
self.single_type_manager.cache_blocks(
request,
self.req_to_num_tokens[request_id],
)
self._free_slots(request_id)
logger.debug(
f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
del self.req_to_free[request_id]
def _free_slots(self, request_id: str):
# This function is designed to be reentrant.
self._release_ahead_touch(request_id)
self.single_type_manager.free(request_id)
self.req_to_block_hashes.pop(request_id, None)
self.req_to_computed_blocks.pop(request_id, None)
self.req_failed_to_allocate.pop(request_id, None)
self.req_to_num_tokens.pop(request_id, None)
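# Expected call order from the scheduler side (these arrive through the ZMQ RPC client used by
# the offloading connector below; `request` is a hypothetical stand-in):
#   1. get_matched_num_and_touch(request)            -> CPU prefix-hit length, blocks touched
#   2. allocate_slots({req_id: num_tokens}, set())   -> CPU block ids for the scheduled requests
#   3. record_request_cache_and_free_slots(request)  -> mark the request to be freed after sending
#   4. cache_and_free_slots(req_id)                  -> cache full blocks and free, once sending is done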

View File

@@ -0,0 +1,528 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import queue
import threading
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Sequence
import torch
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention, MLAAttention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
MambaSpec, MLAAttentionSpec)
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.metadata import (
MetadataServer, MetadataServerProc, MLAConfig)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
@dataclass
class ReqMeta:
gpu_block_ids: list[int]
cpu_block_ids: list[int]
num_scheduled_tokens: int
num_computed_tokens: int
num_gpu_computed_tokens: int
num_cpu_computed_tokens: int
def update(self, other: "ReqMeta"):
self.gpu_block_ids.extend(other.gpu_block_ids)
self.cpu_block_ids.extend(other.cpu_block_ids)
self.num_scheduled_tokens = other.num_scheduled_tokens
self.num_computed_tokens = other.num_computed_tokens
self.num_gpu_computed_tokens = other.num_gpu_computed_tokens
self.num_cpu_computed_tokens = other.num_cpu_computed_tokens
@dataclass
class CPUOffloadingConnectorMetadata(KVConnectorMetadata):
requests: dict[str, ReqMeta]
finished_req_ids: set[str]
class CPUOffloadingConnector(KVConnectorBase_V1):
def __init__(self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional[KVCacheConfig] = None):
if not vllm_config.cache_config.enable_prefix_caching:
self.connector_scheduler: Optional[
CPUOffloadingConnectorScheduler] = None
self.connector_worker: Optional[
CPUOffloadingConnectorWorker] = None
elif role == KVConnectorRole.SCHEDULER:
self.connector_scheduler = CPUOffloadingConnectorScheduler(
vllm_config)
self.connector_worker = None
elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None
self.connector_worker = CPUOffloadingConnectorWorker(vllm_config)
# ==============================
# Worker-side methods
# ==============================
def bind_connector_metadata(
self, connector_metadata: KVConnectorMetadata) -> None:
if self.connector_worker is not None:
assert isinstance(connector_metadata,
CPUOffloadingConnectorMetadata)
self.connector_worker.bind_connector_metadata(connector_metadata)
def clear_connector_metadata(self) -> None:
assert self.connector_worker is not None
self.connector_worker.clear_connector_metadata()
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
if self.connector_worker is not None:
self.connector_worker.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
if self.connector_worker is not None:
self.connector_worker.start_load_kv()
def wait_for_layer_load(self, layer_name: str) -> None:
if self.connector_worker is not None:
self.connector_worker.wait_for_layer_load()
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
pass
def wait_for_save(self):
pass
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
assert self.connector_worker is not None
return self.connector_worker.get_finished(), None
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
if self.connector_scheduler is not None:
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens)
return 0, False
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
if self.connector_scheduler is not None:
return self.connector_scheduler.update_state_after_alloc(request)
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
if self.connector_scheduler is not None:
return self.connector_scheduler.build_connector_meta(
scheduler_output)
return KVConnectorMetadata()
def request_finished(
self, request: "Request",
block_ids: list[int]) -> tuple[bool, Optional[dict[str, Any]]]:
if self.connector_scheduler is not None:
self.connector_scheduler.request_finished(request)
return True, None
class CPUOffloadingConnectorScheduler:
def __init__(self, vllm_config: VllmConfig):
logger.info("init CPUOffloadingConnectorScheduler")
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.use_mla = vllm_config.model_config.use_mla
self.num_gpu_computed_tokens: dict[str, int] = {}
self.num_cpu_computed_tokens: dict[str, int] = {}
self.allocated_req_ids: set[str] = set()
self.finished_req_ids: list[str] = []
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
self.zmq_rpc_client.call("post_init")
if vllm_config.kv_transfer_config is not None:
self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config(
"swap_in_threshold", 0)
else:
self.swap_in_threshold = 0
logger.info(f"swap_in_threshold: {self.swap_in_threshold}")
def get_num_new_matched_tokens(
self, ori_request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
request = copy.deepcopy(ori_request)
request.get_hash_new_full_blocks = None
num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call(
"get_matched_num_and_touch", request)
self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens
self.num_cpu_computed_tokens[
request.request_id] = num_cpu_computed_tokens
if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold:
return num_cpu_computed_tokens - num_computed_tokens, load_async
else:
return 0, load_async
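# e.g. with swap_in_threshold=128: a CPU hit of 512 tokens against 448 tokens already computed
# on the NPU is only a 64-token gain, below the threshold, so 0 is reported and nothing is swapped in.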
def update_state_after_alloc(self, request: "Request"):
self.allocated_req_ids.add(request.request_id)
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
num_tokens = {}
# process scheduled_new_reqs
for req in scheduler_output.scheduled_new_reqs:
req_id = req.req_id
num_tokens[req_id] = (
req.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
# process scheduled_cached_reqs
cached_reqs = scheduler_output.scheduled_cached_reqs
for idx, req_id in enumerate(cached_reqs.req_ids):
num_tokens[req_id] = (
cached_reqs.num_computed_tokens[idx] +
scheduler_output.num_scheduled_tokens[req_id])
unallocated_req_ids = set(self.num_gpu_computed_tokens.keys() -
self.allocated_req_ids -
scheduler_output.num_scheduled_tokens.keys())
new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots",
num_tokens,
unallocated_req_ids)
metadata = CPUOffloadingConnectorMetadata(
requests={},
finished_req_ids=set(self.finished_req_ids),
)
for req in scheduler_output.scheduled_new_reqs:
req_id = req.req_id
gpu_block_ids = req.block_ids[0]
metadata.requests[req_id] = ReqMeta(
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
num_scheduled_tokens=scheduler_output.
num_scheduled_tokens[req_id],
num_computed_tokens=req.num_computed_tokens,
num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id],
num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id])
for idx, req_id in enumerate(cached_reqs.req_ids):
gpu_block_ids = cached_reqs.new_block_ids[idx]
metadata.requests[req_id] = ReqMeta(
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
num_scheduled_tokens=scheduler_output.
num_scheduled_tokens[req_id],
num_computed_tokens=cached_reqs.num_computed_tokens[idx],
num_gpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx])
self.num_gpu_computed_tokens.clear()
self.num_cpu_computed_tokens.clear()
self.allocated_req_ids.clear()
self.finished_req_ids.clear()
return metadata
def request_finished(self, ori_request: "Request"):
request = copy.deepcopy(ori_request)
request.get_hash_new_full_blocks = None
self.finished_req_ids.append(request.request_id)
# inform metadata server to record request, and free it after finish sending
self.zmq_rpc_client.call("record_request_cache_and_free_slots",
request)
class CPUOffloadingConnectorWorker:
def __init__(self, vllm_config: VllmConfig):
logger.info("init CPUOffloadingConnectorWorker")
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.pp_rank = get_pp_group().rank_in_group
self.tp_group = get_tp_group()
self.tp_rank = self.tp_group.rank_in_group
self.tp_world_size = self.tp_group.world_size
self.use_mla = vllm_config.model_config.use_mla
self.requests: dict[str, ReqMeta] = {}
self.load_stream = torch.npu.Stream()
self.save_stream = torch.npu.Stream()
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
self.load_block_mapping: list[tuple[int, int]] = []
self.save_input_queue: queue.Queue[tuple[str, ReqMeta]] = queue.Queue()
self.save_output_queue: queue.Queue[str] = queue.Queue()
self.save_thread = threading.Thread(target=self._save_listener)
self.save_thread.start()
self.done_sending_count: defaultdict[str, int] = defaultdict(int)
# start metadata server to init cpu_kv_cache_manager and handle rpc requests
# all DP ranks share the same metadata server; only start it on data_parallel_rank 0
if vllm_config.parallel_config.data_parallel_rank == 0 and self.tp_rank == 0 and self.pp_rank == 0:
config = VllmConfig()
config.cache_config = vllm_config.cache_config
config.parallel_config = vllm_config.parallel_config
config.kv_transfer_config = vllm_config.kv_transfer_config
self.init_metadata_server(config)
self._wait_for_metadata_process_start()
def init_metadata_server(self, vllm_config: VllmConfig):
self.metadata_thread = threading.Thread(
target=MetadataServerProc.run_metadata_server,
args=(vllm_config, ),
)
self.metadata_thread.daemon = True
self.metadata_thread.start()
def _wait_for_metadata_process_start(self):
# TODO: wait for metadata server to start, add a rpc to check if ready
while True:
try:
if self.zmq_rpc_client.call("ready"):
break
except Exception as e:
logger.info(f"wait for metadata server to start, error: {e}")
time.sleep(1)
def bind_connector_metadata(
self, connector_metadata: CPUOffloadingConnectorMetadata) -> None:
for req_id, req in connector_metadata.requests.items():
if req_id in self.requests:
self.requests[req_id].update(req)
req = self.requests[req_id]
else:
self.requests[req_id] = req
for i in range(req.num_gpu_computed_tokens // self.block_size,
req.num_computed_tokens // self.block_size):
self.load_block_mapping.append(
(req.cpu_block_ids[i], req.gpu_block_ids[i]))
for req_id in connector_metadata.finished_req_ids:
if req_id in self.requests:
self.save_input_queue.put((req_id, self.requests[req_id]))
def clear_connector_metadata(self) -> None:
self.load_block_mapping.clear()
def register_kv_caches(self, kv_caches: dict[str, Sequence[torch.Tensor]]):
self.gpu_kv_caches = kv_caches
model_config = self.vllm_config.model_config
mla_config: Optional[MLAConfig] = None
if model_config.use_mla:
mla_config = MLAConfig(
model_config.hf_text_config.kv_lora_rank,
model_config.hf_text_config.qk_rope_head_dim)
self.cpu_kv_caches = list(
self.zmq_rpc_client.call(
"init_cpu_kv_caches",
self.pp_rank,
self.tp_rank,
get_kv_cache_spec(self.vllm_config),
mla_config,
).values())
def start_load_kv(self) -> None:
self.current_layer = 0
self.gpu_kv_caches_load_iter = iter(self.gpu_kv_caches.values())
self.load_kv_layer(0)
def wait_for_layer_load(self) -> None:
# TODO: Replace with `torch.npu.current_stream().wait_stream(self.load_stream)` after fixing the bug.
self.load_stream.synchronize()
self.current_layer += 1
self.load_kv_layer(self.current_layer)
def load_kv_layer(self, layer: int):
if layer == len(self.gpu_kv_caches):
return
gpu_kv_caches = next(self.gpu_kv_caches_load_iter)
cpu_kv_caches = self.cpu_kv_caches[layer]
with torch.npu.stream(self.load_stream):
for cpu_block_id, gpu_block_id in self.load_block_mapping:
for gpu_layer_part, cpu_layer_part in zip(
gpu_kv_caches, cpu_kv_caches):
gpu_layer_part[gpu_block_id].copy_(
cpu_layer_part[cpu_block_id], non_blocking=True)
def get_finished(self) -> set[str]:
done_sending: set[str] = set()
while True:
try:
id = self.save_output_queue.get_nowait()
except queue.Empty:
break
done_sending.add(id)
for id in done_sending:
del self.requests[id]
if self.tp_world_size == 1:
return done_sending
if self.tp_rank == 0:
for req_id in done_sending:
self.done_sending_count[req_id] += 1
other_ranks_finished_ids: list[str] = []
for i in range(1, self.tp_world_size):
other_ranks_finished_ids.extend(
self.tp_group.recv_object(src=i))
for req_id in other_ranks_finished_ids:
self.done_sending_count[req_id] += 1
all_done_sending: set[str] = set()
for req_id in list(self.done_sending_count.keys()):
if self.done_sending_count[req_id] == self.tp_world_size:
del self.done_sending_count[req_id]
all_done_sending.add(req_id)
# release the cpu kv cache after the request has finished sending;
# to avoid blocking on the rpc, call it asynchronously in a thread
sending_finished_thread = threading.Thread(
target=self._sending_finished, args=(all_done_sending, ))
sending_finished_thread.daemon = True
sending_finished_thread.start()
return all_done_sending
else:
self.tp_group.send_object(done_sending, dst=0)
return done_sending
def _sending_finished(self, all_done_sending):
for req_id in all_done_sending:
logger.debug(f"call cache_and_free_slots for req_id: {req_id}")
self.zmq_rpc_client.call("cache_and_free_slots", req_id)
def _save_listener(self):
save_block_mapping = []
while True:
req_id, req = self.save_input_queue.get()
for i in range(
req.num_cpu_computed_tokens // self.block_size,
min((req.num_computed_tokens + req.num_scheduled_tokens) //
self.block_size, len(req.cpu_block_ids))):
save_block_mapping.append(
(req.gpu_block_ids[i], req.cpu_block_ids[i]))
with torch.npu.stream(self.save_stream):
# MLA: kv_layer is a tuple[Tensor, Tensor], meaning (rope, nope).
# non-MLA: kv_layer is a list[Tensor], typically [k, v].
if self.use_mla:
start, step = self.tp_rank, self.tp_world_size
else:
start, step = 0, 1
for i in range(start, len(save_block_mapping), step):
gpu_block_id, cpu_block_id = save_block_mapping[i]
for cpu_kv_caches, gpu_kv_caches in zip(
self.cpu_kv_caches, self.gpu_kv_caches.values()):
for cpu_layer_part, gpu_layer_part in zip(
cpu_kv_caches, gpu_kv_caches):
cpu_layer_part[cpu_block_id].copy_(
gpu_layer_part[gpu_block_id],
non_blocking=True)
self.save_stream.synchronize()
self.save_output_queue.put(req_id)
save_block_mapping.clear()
# copied and modified from vllm_ascend/worker/model_runner_v1.py
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
Returns:
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
"""
if has_ec_transfer() and get_ec_transfer().is_producer:
return {}
block_size = vllm_config.cache_config.block_size
use_mla = vllm_config.model_config.use_mla
use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
if vllm_config.cache_config.cache_dtype == "auto":
kv_cache_dtype = vllm_config.model_config.dtype
else:
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
vllm_config.cache_config.cache_dtype]
kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
for layer_name, attn_module in attn_layers.items():
if isinstance(attn_module, Attention):
# TODO: Support other attention modules, e.g., cross-attention
# TODO(lucas): move the attention specs into the model layers like
# the attention backends
if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=kv_cache_dtype)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache.
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else:
raise ValueError(
f"Unknown attention type: {attn_module.attn_type}")
elif isinstance(attn_module, MLAAttention):
if use_mla and not use_sparse:
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=attn_module.head_size,
dtype=kv_cache_dtype,
cache_dtype_str=vllm_config.cache_config.cache_dtype)
else:
# TODO(cmq): This is a hacky way to fix the deepseek kv cache when
# using DSA. Fixing the spec in vLLM is the final solution.
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=attn_module.head_size,
dtype=kv_cache_dtype)
mamba_layers = get_layers_from_vllm_config(vllm_config, MambaBase)
if len(mamba_layers) > 0:
if (vllm_config.speculative_config is not None
and vllm_config.model_config.hf_config.model_type
not in ["qwen3_next"]):
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet.")
if vllm_config.cache_config.enable_prefix_caching:
raise NotImplementedError(
"Prefix caching is not supported for Mamba yet.")
max_model_len = vllm_config.model_config.max_model_len
page_size_padded = (vllm_config.cache_config.mamba_page_size_padded)
# Set block_size to max_model_len, so that mamba model will always
# have only one block in the KV cache.
for layer_name, mamba_module in mamba_layers.items():
kv_cache_spec[layer_name] = MambaSpec(
shapes=mamba_module.get_state_shape(),
dtypes=mamba_module.get_state_dtype(),
block_size=max_model_len,
page_size_padded=page_size_padded,
mamba_type=mamba_module.mamba_type,
num_speculative_blocks=(
vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config else 0),
)
return kv_cache_spec

View File

@@ -0,0 +1,246 @@
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Any, Optional
import torch
from ucm.integration.vllm.ucm_connector import UCMConnector
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics, KVConnectorStats, PromMetric, PromMetricT)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
class UCMConnectorV1(KVConnectorBase_V1):
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig",
):
super().__init__(vllm_config=vllm_config,
role=role,
kv_cache_config=kv_cache_config)
assert vllm_config.kv_transfer_config is not None
ImplCls = UCMConnector
self._ucm_engine = ImplCls(vllm_config, role)
# ==============================
# Worker-side methods
# ==============================
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]) -> None:
"""
Initialize with the KV caches. Useful for pre-registering the
KV Caches in the KVConnector (e.g. for NIXL).
Args:
kv_caches: A dictionary mapping layer names to KV cache tensors.
"""
self._ucm_engine.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs: Any) -> None:
"""
Start loading the KV cache from the connector to vLLM's paged
KV buffer. This is called from the forward context before the
forward pass to enable async loading during model execution.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
self._ucm_engine.start_load_kv(forward_context, **kwargs)
def wait_for_layer_load(self, layer_name: str) -> None:
"""
Block until the KV for a specific layer is loaded into vLLM's
paged buffer. This is called from within attention layer to ensure
async copying from start_load_kv is complete.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
self._ucm_engine.wait_for_layer_load(layer_name)
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs: Any,
) -> None:
"""
Start saving a layer of KV cache from vLLM's paged buffer
to the connector. This is called from within attention layer to
enable async copying during execution.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
self._ucm_engine.save_kv_layer(layer_name, kv_layer, attn_metadata,
**kwargs)
def wait_for_save(self) -> None:
"""
Block until all the save operations are done. This is called
as the forward context exits to ensure that the async saving
from save_kv_layer is complete before finishing the forward.
This prevents overwrites of paged KV buffer before saving done.
"""
self._ucm_engine.wait_for_save()
def clear_connector_metadata(self) -> None:
"""Clear the connector metadata.
This function should be called by the model runner every time
after the model execution.
"""
self._ucm_engine.clear_connector_metadata()
def bind_connector_metadata(
self, connector_metadata: KVConnectorMetadata) -> None:
"""Set the connector metadata from the scheduler.
This function should be called by the model runner every time
before the model execution. The metadata will be used for runtime
KV cache loading and saving.
Args:
connector_metadata (dict): the connector metadata.
"""
self._ucm_engine.bind_connector_metadata(connector_metadata)
def get_block_ids_with_load_errors(self) -> set[int]:
"""
Get the set of block IDs that failed to load.
Returns:
Set of block IDs that encountered load errors.
Empty set if no load errors occurred.
"""
return self._ucm_engine.get_block_ids_with_load_errors()
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int | None, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
return self._ucm_engine.get_num_new_matched_tokens(
request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int) -> None:
"""
Update KVConnector state after block allocation.
"""
self._ucm_engine.update_state_after_alloc(request, blocks,
num_external_tokens)
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
"""
Build the connector metadata for this step.
This function should NOT modify fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
return self._ucm_engine.build_connector_meta(scheduler_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
"""
Called when a request has finished, before its blocks are freed.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
return self._ucm_engine.request_finished(request, block_ids)
# ==============================
# Metrics & Stats
# ==============================
@classmethod
def build_kv_connector_stats(
cls,
data: dict[str, Any] | None = None
) -> Optional["KVConnectorStats"]:
"""
KVConnectorStats resolution method. This method allows dynamically
registered connectors to return their own KVConnectorStats object,
which can implement custom aggregation logic on the data dict.
"""
return UCMConnector.build_kv_connector_stats(data)
@classmethod
def build_prom_metrics(
cls,
vllm_config: "VllmConfig",
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[object]],
) -> Optional["KVConnectorPromMetrics"]:
"""
Create a KVConnectorPromMetrics subclass which should register
per-connector Prometheus metrics and implement observe() to
expose connector transfer stats via Prometheus.
This implementation forwards the call to the underlying
UCMConnector engine.
"""
return UCMConnector.build_prom_metrics(
vllm_config,
metric_types,
labelnames,
per_engine_labelvalues,
)
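# Usage sketch (the registered connector name and config values are assumptions; see the
# connector factory registration): this wrapper is selected through vLLM's kv-transfer config,
# e.g. --kv-transfer-config '{"kv_connector": "UCMConnector", "kv_role": "kv_both"}',
# after which every scheduler/worker call above is forwarded to the wrapped UCMConnector engine.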

View File

@@ -0,0 +1,53 @@
import ipaddress
import threading
from typing import Optional
from mooncake.engine import TransferEngine # type: ignore
class GlobalTE():
def __init__(self):
self.transfer_engine = None
self.is_register_buffer: bool = False
self.transfer_engine_lock = threading.Lock()
self.register_buffer_lock = threading.Lock()
def get_transfer_engine(self, hostname: str, device_name: Optional[str]):
try:
ip = ipaddress.ip_address(hostname)
if isinstance(ip, ipaddress.IPv6Address):
raise RuntimeError(
"The backend of mooncake's Ascend Direct Xfer Library currently does not support IPv6."
)
except ValueError:
pass
if self.transfer_engine is None:
with self.transfer_engine_lock:
# Double-Checked Locking
if self.transfer_engine is None:
if TransferEngine is None:
raise RuntimeError("mooncake is not available")
self.transfer_engine = TransferEngine()
device_name = device_name if device_name is not None else ""
ret_value = self.transfer_engine.initialize(
hostname, "P2PHANDSHAKE", "ascend", device_name)
if ret_value != 0:
raise RuntimeError(
f"TransferEngine initialization failed with ret_value: {ret_value}"
)
return self.transfer_engine
def register_buffer(self, ptrs: list[int], sizes: list[int]):
with self.register_buffer_lock:
assert self.transfer_engine is not None, "Transfer engine must be initialized"
if self.is_register_buffer:
return
for ptr, size in zip(ptrs, sizes):
ret_value = self.transfer_engine.register_memory(ptr, size)
if ret_value != 0:
raise RuntimeError("Mooncake memory registration failed.")
self.is_register_buffer = True
global_te = GlobalTE()
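# Usage sketch (hypothetical hostname and buffer): callers share the singleton above, e.g.
#   engine = global_te.get_transfer_engine("127.0.0.1", device_name=None)
#   global_te.register_buffer(ptrs=[tensor.data_ptr()],
#                             sizes=[tensor.numel() * tensor.element_size()])
# Repeated register_buffer() calls are no-ops thanks to the is_register_buffer flag.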

View File

@@ -0,0 +1,61 @@
import os
import torch
import torch.distributed as dist
from vllm_ascend.distributed.parallel_state import get_p_tp_group
def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,
value: torch.Tensor):
if pd_tp_ratio <= 1:
return None, None
elif key is None or value is None:
raise ValueError("key or value is None")
k_output = alltoall_and_rearrange(pd_tp_ratio, key)
v_output = alltoall_and_rearrange(pd_tp_ratio, value)
return k_output, v_output
def alltoall_and_rearrange(tp_ratio: int, input_tensor: torch.Tensor):
num_kv_heads = input_tensor.size(1)
output_tensor = torch.zeros_like(input_tensor)
dist.all_to_all_single(output_tensor,
input_tensor,
group=get_p_tp_group().device_group)
input_tensor = 0
result = rearrange_output(output_tensor, tp_ratio, num_kv_heads)
output_tensor = 0
return result
def rearrange_output(base_output: torch.Tensor, cut_num: int,
num_kv_heads: int):
size_0 = base_output.size(0)
if size_0 % cut_num != 0:
raise ValueError(
f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]"
)
chunk_size = size_0 // cut_num
reshaped = base_output.view(cut_num, chunk_size, -1)
transposed = reshaped.transpose(0, 1)
return transposed.contiguous().view(size_0, num_kv_heads, -1)
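# Illustration (made-up rows): with size_0=4, cut_num=2 and rows [r0, r1, r2, r3],
# view(2, 2, -1) groups them into chunks [r0, r1] and [r2, r3]; the transpose interleaves
# the chunks so the final layout is [r0, r2, r1, r3], i.e. matching rows of each chunk
# end up adjacent.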
def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
data_ptr = tensor.data_ptr()
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
offset = (aligned_addr - data_ptr) // tensor.element_size()
return tensor[int(offset):]
def get_transfer_timeout_value():
ascend_transfer_timeout = os.getenv("ASCEND_TRANSFER_TIMEOUT", "")
if len(ascend_transfer_timeout) > 0:
return int(ascend_transfer_timeout)
hccl_rdma_timeout = int(os.getenv('HCCL_RDMA_TIMEOUT',
'20')) # type: ignore
hccl_rdma_retry_cnt = int(os.getenv('HCCL_RDMA_RETRY_CNT',
'7')) # type: ignore
return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 +
3000)
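# Worked example of the fallback above with the env vars unset (HCCL_RDMA_TIMEOUT=20,
# HCCL_RDMA_RETRY_CNT=7), assuming the 4.096 factor is the RDMA ack timeout in microseconds:
# 4.096 * 2**20 ~= 4.29 s per attempt, * 7 retries ~= 30.06 s, // 1000 -> 30064, + 3000 -> 33064.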