[Refactor] Refactor of vllm_ascend/distributed module (#5719)
### What this PR does / why we need it?
Based on the RFC: https://github.com/vllm-project/vllm-ascend/issues/5604
This PR refactors vllm_ascend/distributed, moving all kv_transfer-related
code into a dedicated folder, mirroring the layout already used in vLLM.
### Does this PR introduce _any_ user-facing change?
NA
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
---------
Signed-off-by: lty <linhebiwen@gmail.com>
vllm_ascend/distributed/kv_transfer/__init__.py (new file, 45 lines)
@@ -0,0 +1,45 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from vllm.distributed.kv_transfer.kv_connector.factory import \
    KVConnectorFactory


def register_connector():
    KVConnectorFactory.register_connector(
        "MooncakeConnectorV1",
        "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector",
        "MooncakeConnector")

    KVConnectorFactory.register_connector(
        "MooncakeConnectorStoreV1",
        "vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.ascend_store_connector",
        "AscendStoreConnector")

    KVConnectorFactory.register_connector(
        "AscendStoreConnector",
        "vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.ascend_store_connector",
        "AscendStoreConnector")

    KVConnectorFactory.register_connector(
        "MooncakeLayerwiseConnector",
        "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector",
        "MooncakeLayerwiseConnector")

    KVConnectorFactory.register_connector(
        "UCMConnector", "vllm_ascend.distributed.kv_transfer.ucm_connector",
        "UCMConnectorV1")
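For reference, a minimal sketch of how one of the connectors registered above can be selected when constructing an engine. The connector name, role, and extra-config keys come from the code in this PR; the model name and the specific values are placeholders, not recommended settings.

# Hedged sketch: wiring a registered connector into a vLLM engine.
from vllm import LLM
from vllm.config import KVTransferConfig

kv_transfer_config = KVTransferConfig(
    kv_connector="AscendStoreConnector",
    kv_role="kv_both",
    kv_connector_extra_config={"use_layerwise": False},
)

# Placeholder model; any model supported on Ascend would do here.
llm = LLM(model="Qwen/Qwen2.5-7B-Instruct",
          kv_transfer_config=kv_transfer_config)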
vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py (new file, 1834 lines)
File diff suppressed because it is too large.
@@ -0,0 +1,184 @@
|
||||
import threading
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.logger import logger
|
||||
from vllm.utils.network_utils import make_zmq_socket
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.serial_utils import MsgpackDecoder
|
||||
|
||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_scheduler import (
|
||||
KVPoolScheduler, get_zmq_rpc_path_lookup)
|
||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_worker import \
|
||||
KVPoolWorker
|
||||
|
||||
|
||||
class AscendStoreConnector(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional[KVCacheConfig] = None):
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
role=role,
|
||||
kv_cache_config=kv_cache_config)
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
|
||||
self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"use_layerwise", False)
|
||||
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"consumer_is_to_put", False)
|
||||
|
||||
connector_name = vllm_config.kv_transfer_config.kv_connector
|
||||
if connector_name == "MooncakeConnectorStoreV1":
|
||||
logger.warning(
|
||||
"It is recommended to use the AscendStoreConnector, as the MoonCakeStoreConnector will be removed in the future."
|
||||
)
|
||||
|
||||
self.kv_caches: dict[str, torch.Tensor] = {}
|
||||
|
||||
self.sended_but_unfinished_reqs: set[str] = set()
|
||||
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler = KVPoolScheduler(vllm_config,
|
||||
self.use_layerwise)
|
||||
else:
|
||||
self.connector_worker = KVPoolWorker(
|
||||
vllm_config,
|
||||
self.use_layerwise,
|
||||
)
|
||||
|
||||
assert self.connector_worker is not None
|
||||
if vllm_config.parallel_config.rank == 0:
|
||||
self.lookup_server = LookupKeyServer(self.connector_worker,
|
||||
vllm_config,
|
||||
self.use_layerwise)
|
||||
|
||||
############################################################
|
||||
# Scheduler Side Methods
|
||||
############################################################
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens)
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.update_state_after_alloc(
|
||||
request, blocks, num_external_tokens)
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.build_connector_meta(scheduler_output)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.request_finished(request, block_ids)
|
||||
|
||||
############################################################
|
||||
# Worker Side Methods
|
||||
############################################################
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.start_load_kv(self._get_connector_metadata())
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
if not self.use_layerwise:
|
||||
return
|
||||
self.connector_worker.wait_for_layer_load()
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||
if not self.use_layerwise:
|
||||
return
|
||||
|
||||
if self.kv_role == "kv_consumer":
|
||||
# Don't do save if the role is kv_consumer
|
||||
return
|
||||
self.connector_worker.save_kv_layer(self._get_connector_metadata())
|
||||
|
||||
def wait_for_save(self):
|
||||
if self.kv_role == "kv_consumer" and not self.consumer_is_to_put:
|
||||
# Don't do save if the role is kv_consumer
|
||||
return
|
||||
|
||||
if self.use_layerwise:
|
||||
return
|
||||
|
||||
self.connector_worker.wait_for_save(self._get_connector_metadata())
|
||||
|
||||
def get_finished(self,
|
||||
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
||||
"""Get the finished recving and sending requests."""
|
||||
assert self.connector_worker is not None
|
||||
done_sending, done_recving = self.connector_worker.get_finished(
|
||||
finished_req_ids)
|
||||
return done_sending, done_recving
|
||||
|
||||
|
||||
class LookupKeyServer:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pool_worker: KVPoolWorker,
|
||||
vllm_config: "VllmConfig",
|
||||
use_layerwise: bool,
|
||||
):
|
||||
self.decoder = MsgpackDecoder()
|
||||
self.decoder_tensor = MsgpackDecoder(torch.Tensor)
|
||||
self.ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
socket_path = get_zmq_rpc_path_lookup(vllm_config)
|
||||
self.socket = make_zmq_socket(
|
||||
self.ctx,
|
||||
socket_path,
|
||||
zmq.REP, # type: ignore[attr-defined]
|
||||
bind=True,
|
||||
)
|
||||
|
||||
self.pool_worker = pool_worker
|
||||
self.running = True
|
||||
self.use_layerwise = use_layerwise
|
||||
|
||||
def process_request():
|
||||
while self.running:
|
||||
all_frames = self.socket.recv_multipart(copy=False)
|
||||
token_len = int.from_bytes(all_frames[0], byteorder="big")
|
||||
hash_frames = all_frames[1:]
|
||||
hashes_str = self.decoder.decode(hash_frames)
|
||||
result = self.pool_worker.lookup_scheduler(
|
||||
token_len, hashes_str, self.use_layerwise)
|
||||
response = result.to_bytes(4, "big")
|
||||
self.socket.send(response)
|
||||
|
||||
self.thread = threading.Thread(target=process_request, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def close(self):
|
||||
self.socket.close(linger=0)
|
||||
        # TODO: stop the process_request loop (set self.running = False) and join the thread.
|
||||
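For reference, a sketch of the client side of the lookup RPC that LookupKeyServer answers. The framing mirrors process_request() above (a big-endian token length frame followed by msgpack-encoded block hashes, answered with a 4-byte big-endian hit count), but the 8-byte length frame, the REQ socket type, and the helper name are assumptions; the real client lives in pool_scheduler.py.

import zmq
from vllm.v1.serial_utils import MsgpackEncoder


def lookup_cached_blocks(socket_path: str, token_len: int,
                         block_hashes: list[str]) -> int:
    # Frame 0: token length as big-endian bytes (8-byte width is an
    # assumption); remaining frames: msgpack-encoded block hashes, matching
    # what LookupKeyServer.process_request() decodes on the other side.
    encoder = MsgpackEncoder()
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REQ)
    sock.connect(socket_path)
    frames = [token_len.to_bytes(8, "big"), *encoder.encode(block_hashes)]
    sock.send_multipart(frames)
    # The server replies with a 4-byte big-endian count of matched tokens.
    return int.from_bytes(sock.recv(), "big")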
@@ -0,0 +1 @@
@@ -0,0 +1,29 @@
from abc import ABC, abstractmethod

from vllm.config import ParallelConfig


class Backend(ABC):

    def __init__(self, parallel_config: ParallelConfig):
        pass

    def set_device(self):
        pass

    def register_buffer(self, ptrs: list[int], lengths: list[int]):
        pass

    @abstractmethod
    def exists(self, keys: list[str]) -> list[int]:
        pass

    @abstractmethod
    def put(self, keys: list[str], addrs: list[list[int]],
            sizes: list[list[int]]):
        pass

    @abstractmethod
    def get(self, keys: list[str], addrs: list[list[int]],
            sizes: list[list[int]]):
        pass
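To make the contract concrete, here is a toy, host-memory-only Backend implementation. It only illustrates how exists/put/get receive parallel lists of keys, raw buffer addresses, and byte sizes, and it assumes the addresses are host-accessible; real backends such as the Mooncake and memcache ones below move data between NPU buffers and a remote store.

# Hedged sketch: an in-process Backend that stages blobs in host memory.
import ctypes

from vllm.config import ParallelConfig

from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
    Backend


class InMemoryBackend(Backend):

    def __init__(self, parallel_config: ParallelConfig):
        super().__init__(parallel_config)
        self._blobs: dict[str, list[bytes]] = {}

    def exists(self, keys: list[str]) -> list[int]:
        return [1 if key in self._blobs else 0 for key in keys]

    def put(self, keys: list[str], addrs: list[list[int]],
            sizes: list[list[int]]):
        # Copy each (address, size) buffer of a key into host memory.
        for key, addr_list, size_list in zip(keys, addrs, sizes):
            self._blobs[key] = [
                ctypes.string_at(addr, size)
                for addr, size in zip(addr_list, size_list)
            ]

    def get(self, keys: list[str], addrs: list[list[int]],
            sizes: list[list[int]]):
        # Copy the stored blobs back into the caller-provided buffers.
        for key, addr_list, size_list in zip(keys, addrs, sizes):
            for blob, addr, size in zip(self._blobs[key], addr_list,
                                        size_list):
                ctypes.memmove(addr, blob, min(size, len(blob)))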
@@ -0,0 +1,96 @@
|
||||
# Standard
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
|
||||
Backend
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||
|
||||
|
||||
class MmcDirect(Enum):
|
||||
COPY_L2G = 0
|
||||
COPY_G2L = 1
|
||||
COPY_G2H = 2
|
||||
COPY_H2G = 3
|
||||
|
||||
|
||||
class MemcacheBackend(Backend):
|
||||
|
||||
def __init__(self, parallel_config: ParallelConfig):
|
||||
try:
|
||||
from memcache_hybrid import DistributedObjectStore # type: ignore
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install memcache by following the instructions at "
|
||||
"https://gitee.com/ascend/memfabric_hybrid " # noqa: E501
|
||||
"to run vLLM with MemcacheConnector.") from e
|
||||
try:
|
||||
soc_version = get_ascend_device_type()
|
||||
if soc_version in {AscendDeviceType.A2}:
|
||||
import torch
|
||||
from vllm.distributed import get_world_group
|
||||
tmp_tensor = torch.zeros(1, device="npu")
|
||||
output_tensor_list = [
|
||||
torch.empty_like(tmp_tensor)
|
||||
for _ in range(torch.distributed.get_world_size())
|
||||
]
|
||||
torch.distributed.all_gather(
|
||||
output_tensor_list,
|
||||
tmp_tensor,
|
||||
group=get_world_group().device_group)
|
||||
self.rank = parallel_config.rank
|
||||
self.store = DistributedObjectStore()
|
||||
res = self.store.init(self.rank)
|
||||
assert res == 0
|
||||
else:
|
||||
self.rank = parallel_config.rank
|
||||
self.store = DistributedObjectStore()
|
||||
res = self.store.init(self.rank)
|
||||
assert res == 0
|
||||
except ValueError as e:
|
||||
logger.error("Configuration loading failed: %s", e)
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"An error occurred while loading the configuration: %s", exc)
|
||||
raise
|
||||
|
||||
def set_device(self):
|
||||
device = torch.device(f"npu:{self.rank}")
|
||||
torch.npu.set_device(device)
|
||||
|
||||
def register_buffer(self, ptrs: list[int], sizes: list[int]):
|
||||
soc_version = get_ascend_device_type()
|
||||
if soc_version in {AscendDeviceType.A2}:
|
||||
for ptr, size in zip(ptrs, sizes):
|
||||
self.store.register_buffer(ptr, size)
|
||||
else:
|
||||
pass
|
||||
|
||||
def exists(self, keys: list[str]) -> list[int]:
|
||||
return self.store.batch_is_exist(keys)
|
||||
|
||||
def get(self, key: list[str], addr: list[list[int]],
|
||||
size: list[list[int]]):
|
||||
try:
|
||||
res = self.store.batch_get_into_layers(key, addr, size,
|
||||
MmcDirect.COPY_G2L.value)
|
||||
for value in res:
|
||||
if value != 0:
|
||||
logger.error(f"Failed to get key {key},res:{res}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get key {key}. {e}")
|
||||
|
||||
def put(self, key: list[str], addr: list[list[int]],
|
||||
size: list[list[int]]):
|
||||
try:
|
||||
res = self.store.batch_put_from_layers(key, addr, size,
|
||||
MmcDirect.COPY_L2G.value)
|
||||
for value in res:
|
||||
if value != 0:
|
||||
logger.error(f"Failed to get key {key},res:{res}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to put key {key},error:{e}")
|
||||
@@ -0,0 +1,192 @@
|
||||
# Standard
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
|
||||
# Third Party
|
||||
from mooncake.store import ReplicateConfig # type: ignore
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.logger import logger
|
||||
from vllm.utils.network_utils import get_ip
|
||||
|
||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
|
||||
Backend
|
||||
from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import \
|
||||
global_te
|
||||
|
||||
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
|
||||
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
|
||||
|
||||
|
||||
class MooncakeBackend(Backend):
|
||||
|
||||
def __init__(self, parallel_config: ParallelConfig):
|
||||
try:
|
||||
from mooncake.store import MooncakeDistributedStore # type: ignore
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install mooncake by following the instructions at "
|
||||
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
|
||||
"to run vLLM with MooncakeConnector.") from e
|
||||
self.config = MooncakeStoreConfig.load_from_env()
|
||||
self.store = MooncakeDistributedStore()
|
||||
if self.config.protocol == "ascend":
|
||||
local_hostname = get_ip()
|
||||
transfer_engine = global_te.get_transfer_engine(local_hostname,
|
||||
device_name=None)
|
||||
self.local_seg = local_hostname + ":" + str(
|
||||
transfer_engine.get_rpc_port())
|
||||
ret = self.store.setup(self.local_seg, self.config.metadata_server,
|
||||
self.config.global_segment_size,
|
||||
self.config.local_buffer_size,
|
||||
self.config.protocol,
|
||||
self.config.device_name,
|
||||
self.config.master_server_address,
|
||||
transfer_engine.get_engine())
|
||||
if ret != 0:
|
||||
msg = "Initialize mooncake failed."
|
||||
logger.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
def register_buffer(self, ptrs: list[int], lengths: list[int]):
|
||||
global_te.register_buffer(ptrs, lengths)
|
||||
|
||||
def exists(self, keys: list[str]) -> list[int]:
|
||||
return self.store.batch_is_exist(keys)
|
||||
|
||||
def put(self, keys: list[str], addrs: list[list[int]],
|
||||
sizes: list[list[int]]):
|
||||
try:
|
||||
config = ReplicateConfig()
|
||||
config.preferred_segment = self.local_seg
|
||||
config.prefer_alloc_in_same_node = True
|
||||
res = self.store.batch_put_from_multi_buffers(
|
||||
keys, addrs, sizes, config)
|
||||
for value in res:
|
||||
if value < 0:
|
||||
logger.error(f"Failed to put key {keys},res:{res}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to put key {keys},error:{e}")
|
||||
|
||||
def get(self, keys: list[str], addrs: list[list[int]],
|
||||
sizes: list[list[int]]):
|
||||
try:
|
||||
res = self.store.batch_get_into_multi_buffers(
|
||||
keys, addrs, sizes, True)
|
||||
for value in res:
|
||||
if value < 0:
|
||||
logger.error(f"Failed to get key {keys}, res:{res}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get key {keys}, error:{e}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MooncakeStoreConfig:
|
||||
metadata_server: str
|
||||
global_segment_size: Union[int, str]
|
||||
local_buffer_size: int
|
||||
protocol: str
|
||||
device_name: str
|
||||
master_server_address: str
|
||||
|
||||
@staticmethod
|
||||
def from_file(file_path: str) -> "MooncakeStoreConfig":
|
||||
with open(file_path) as file:
|
||||
config = json.load(file)
|
||||
return MooncakeStoreConfig(
|
||||
metadata_server=config.get("metadata_server"),
|
||||
global_segment_size=_parse_global_segment_size(
|
||||
config.get("global_segment_size",
|
||||
DEFAULT_GLOBAL_SEGMENT_SIZE)),
|
||||
local_buffer_size=_parse_global_segment_size(
|
||||
config.get("local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE)),
|
||||
protocol=config.get("protocol", "ascend"),
|
||||
device_name=config.get("device_name", ""),
|
||||
master_server_address=config.get("master_server_address"))
|
||||
|
||||
@staticmethod
|
||||
def load_from_env() -> "MooncakeStoreConfig":
|
||||
config_path = os.getenv("MOONCAKE_CONFIG_PATH")
|
||||
if not config_path:
|
||||
raise ValueError(
|
||||
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
|
||||
return MooncakeStoreConfig.from_file(config_path)
|
||||
|
||||
|
||||
def _parse_global_segment_size(value) -> int:
|
||||
"""
|
||||
Parse storage size strings with support for units: GB, MB, KB, B
|
||||
|
||||
Args:
|
||||
value: Input value (int, str, or other convertible types)
|
||||
|
||||
Returns:
|
||||
int: Size in bytes
|
||||
|
||||
Raises:
|
||||
ValueError: For invalid format, missing number, or negative values
|
||||
TypeError: For unsupported input types
|
||||
"""
|
||||
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
elif not isinstance(value, str):
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError) as e:
|
||||
raise TypeError(
|
||||
f"Unsupported type for global_segment_size: {type(value)}"
|
||||
) from e
|
||||
|
||||
cleaned_input = value.strip().lower()
|
||||
if not cleaned_input:
|
||||
raise ValueError("global segment size cannot be empty.")
|
||||
|
||||
UNIT_MULTIPLIERS = {
|
||||
'gb': 1024**3, # 1 GB = 1024^3 bytes
|
||||
'mb': 1024**2, # 1 MB = 1024^2 bytes
|
||||
'kb': 1024, # 1 KB = 1024 bytes
|
||||
'b': 1 # 1 B = 1 byte
|
||||
}
|
||||
pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$'
|
||||
match = re.match(pattern, cleaned_input)
|
||||
|
||||
if not match:
|
||||
raise ValueError(f"Invalid format: '{value}'")
|
||||
|
||||
number_str = match.group(1)
|
||||
unit = match.group(2) or 'b'
|
||||
|
||||
multiplier = UNIT_MULTIPLIERS[unit]
|
||||
return _convert_to_bytes(number_str, multiplier, value)
|
||||
|
||||
|
||||
def _convert_to_bytes(number_str: str, multiplier: int,
|
||||
original_input: str) -> int:
|
||||
"""
|
||||
Convert numeric string to byte count
|
||||
|
||||
Args:
|
||||
number_str: Numeric portion of input
|
||||
multiplier: Unit conversion factor
|
||||
original_input: Original input string (for error messages)
|
||||
|
||||
Returns:
|
||||
int: Byte count
|
||||
|
||||
Raises:
|
||||
ValueError: For invalid numbers or negative results
|
||||
"""
|
||||
try:
|
||||
numeric_value = float(number_str)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid numeric value '{number_str}' in: '{original_input}'")
|
||||
# Calculate byte count
|
||||
try:
|
||||
byte_count = int(numeric_value * multiplier)
|
||||
except OverflowError:
|
||||
raise ValueError(f"Storage size too large: '{original_input}'")
|
||||
return byte_count
|
||||
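A small usage sketch of the size-parsing helpers above. The arithmetic is checked against the defaults declared at the top of the file, and the JSON keys mirror those read by MooncakeStoreConfig.from_file(); the values shown are placeholders, not recommended settings.

# 3.125 * 1024**3 == 3_355_443_200, i.e. DEFAULT_GLOBAL_SEGMENT_SIZE above.
assert _parse_global_segment_size("3.125 GB") == 3355443200
assert _parse_global_segment_size("512mb") == 512 * 1024**2
assert _parse_global_segment_size(1073741824) == 1073741824

# Illustrative file behind MOONCAKE_CONFIG_PATH (placeholder values):
# {
#     "metadata_server": "<metadata-server-address>",
#     "global_segment_size": "3.125 GB",
#     "local_buffer_size": "1 GB",
#     "protocol": "ascend",
#     "device_name": "",
#     "master_server_address": "<master-ip>:<port>"
# }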
@@ -0,0 +1,404 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
|
||||
KVConnectorMetadata
|
||||
from vllm.logger import logger
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.core.sched.output import NewRequestData
|
||||
|
||||
|
||||
# Parameters related to the key
|
||||
@dataclass
|
||||
class KeyMetadata:
|
||||
"""name of the LLM model"""
|
||||
|
||||
model_name: str
|
||||
""" worker id when running under a distributed setting """
|
||||
head_or_tp_rank: int
|
||||
""" Initialize the current prefill context model parallel rank """
|
||||
pcp_rank: int
|
||||
""" Initialize the current decode context model parallel rank """
|
||||
dcp_rank: int
|
||||
""" Initialize the current pipeline parallel rank """
|
||||
pp_rank: int
|
||||
|
||||
|
||||
@dataclass(order=True)
|
||||
class PoolKey:
|
||||
key_metadata: KeyMetadata
|
||||
chunk_hash: str
|
||||
|
||||
def __hash__(self):
|
||||
return hash((
|
||||
self.key_metadata.model_name,
|
||||
self.key_metadata.head_or_tp_rank,
|
||||
self.key_metadata.pcp_rank,
|
||||
self.key_metadata.dcp_rank,
|
||||
self.key_metadata.pp_rank,
|
||||
self.chunk_hash,
|
||||
))
|
||||
|
||||
def to_string(self):
|
||||
return (
|
||||
f"{self.key_metadata.model_name}"
|
||||
f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
|
||||
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}"
|
||||
f"@pp_rank:{self.key_metadata.pp_rank}@{self.chunk_hash}")
|
||||
|
||||
def split_layers(self, num_layers: int) -> List["LayerPoolKey"]:
|
||||
"""Split the key into multiple keys for each layer"""
|
||||
keys = []
|
||||
for layer_id in range(num_layers):
|
||||
keys.append(
|
||||
LayerPoolKey(
|
||||
self.key_metadata,
|
||||
self.chunk_hash,
|
||||
layer_id,
|
||||
))
|
||||
return keys
|
||||
|
||||
|
||||
@dataclass(order=True)
|
||||
class LayerPoolKey(PoolKey):
|
||||
"""A key for the layer cache engine"""
|
||||
|
||||
layer_id: int
|
||||
|
||||
def __hash__(self):
|
||||
return hash((
|
||||
self.key_metadata.model_name,
|
||||
self.key_metadata.head_or_tp_rank,
|
||||
self.key_metadata.pcp_rank,
|
||||
self.key_metadata.dcp_rank,
|
||||
self.chunk_hash,
|
||||
self.layer_id,
|
||||
))
|
||||
|
||||
def to_string(self):
|
||||
return (
|
||||
f"{self.key_metadata.model_name}"
|
||||
f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
|
||||
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}@{self.layer_id}"
|
||||
)
|
||||
|
||||
|
||||
class ChunkedTokenDatabase():
|
||||
|
||||
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool,
|
||||
partitions: Optional[List[int]]):
|
||||
self.metadata = metadata
|
||||
self.block_size = block_size
|
||||
self.use_mla = use_mla
|
||||
self.kv_caches_base_addr: list[int] = []
|
||||
self.block_len: list[int] = []
|
||||
self.partitions = partitions
|
||||
|
||||
def _make_key_by_hash(self,
|
||||
chunk_hash: str,
|
||||
layer_id: Optional[int] = None):
|
||||
assert self.metadata is not None
|
||||
return PoolKey(
|
||||
self.metadata,
|
||||
chunk_hash,
|
||||
)
|
||||
|
||||
def set_kv_caches_base_addr(self, kv_caches_base_addr: list[int]):
|
||||
self.kv_caches_base_addr = kv_caches_base_addr
|
||||
|
||||
def set_block_len(self, block_len: list[int]):
|
||||
self.block_len = block_len
|
||||
|
||||
def prepare_value(self, start: int, end: int, block_ids: list[int]):
|
||||
addr_list = []
|
||||
size_list = []
|
||||
block_id = block_ids[start // self.block_size]
|
||||
for index, base_addr in enumerate(self.kv_caches_base_addr):
|
||||
block_len = (self.block_len[index % 2]
|
||||
if self.use_mla else self.block_len[0])
|
||||
|
||||
addr = base_addr + block_id * block_len
|
||||
length = int(block_len / self.block_size * (end - start))
|
||||
addr_list.append(addr)
|
||||
size_list.append(length)
|
||||
return addr_list, size_list, block_id
|
||||
|
||||
def prepare_value_layer(self, start: int, end: int, block_ids: list[int],
|
||||
layer_id: int):
|
||||
block_id = block_ids[start // self.block_size]
|
||||
if self.use_mla:
|
||||
addr_k = self.kv_caches_base_addr[layer_id *
|
||||
2] + block_id * self.block_len[0]
|
||||
addr_v = self.kv_caches_base_addr[layer_id * 2 +
|
||||
1] + block_id * self.block_len[1]
|
||||
length_k = int(self.block_len[0] / self.block_size * (end - start))
|
||||
length_v = int(self.block_len[1] / self.block_size * (end - start))
|
||||
size_list = [length_k, length_v]
|
||||
else:
|
||||
addr_k = self.kv_caches_base_addr[layer_id *
|
||||
2] + block_id * self.block_len[0]
|
||||
addr_v = self.kv_caches_base_addr[layer_id * 2 +
|
||||
1] + block_id * self.block_len[0]
|
||||
length = int(self.block_len[0] / self.block_size * (end - start))
|
||||
size_list = [length, length]
|
||||
addr_list = [addr_k, addr_v]
|
||||
return addr_list, size_list
|
||||
|
||||
def process_tokens(
|
||||
self,
|
||||
token_len: int,
|
||||
block_hashes: Union[list[BlockHash], list[str]],
|
||||
mask_num: int = 0,
|
||||
) -> Iterable[Tuple[int, int, PoolKey]]:
|
||||
"""Process the tokens and return the corresponding cache engine keys.
|
||||
|
||||
:param Union[torch.Tensor, List[int]] tokens: The tokens to process.
|
||||
|
||||
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
|
||||
have the same length as tokens. And the mask should ALWAYS be like
|
||||
FFFFFTTTTTTT, where True means the tokens needs to be matched,
|
||||
and the Falses will ALWAYS be at the PREFIX of the tensor.
|
||||
|
||||
:param bool make_key: Whether to make the cache engine key or not.
|
||||
If False, the hash value will be returned instead.
|
||||
|
||||
:returns: A iterable of tuples with three elements. The first element
|
||||
is the start index of the tokens for the key. The second element
|
||||
is the end index of the tokens for the key. The third element is
|
||||
the cache engine key (or hash) for the tokens.
|
||||
|
||||
:raises: ValueError if the number of Falses in the mask is not a
|
||||
multiple of the chunk size.
|
||||
"""
|
||||
if not block_hashes:
|
||||
return
|
||||
if not isinstance(block_hashes[0], str):
|
||||
block_hashes = [
|
||||
h.hex() # type: ignore[union-attr]
|
||||
for h in block_hashes
|
||||
]
|
||||
start_idx = 0
|
||||
for chunk_id, hash_val in enumerate(block_hashes):
|
||||
start_idx = chunk_id * self.block_size
|
||||
if start_idx >= token_len:
|
||||
break
|
||||
end_idx = min(start_idx + self.block_size, token_len)
|
||||
if start_idx < mask_num:
|
||||
continue
|
||||
else:
|
||||
yield start_idx, end_idx, self._make_key_by_hash(hash_val)
|
||||
|
||||
def decode_adaptor_prefill_pp(self, key, addr, size):
|
||||
if self.partitions is None or len(self.partitions) == 1:
|
||||
return key, addr, size
|
||||
|
||||
new_key = []
|
||||
new_addr = []
|
||||
new_size = []
|
||||
|
||||
for i, (addr_list, size_list) in enumerate(zip(addr, size)):
|
||||
start = 0
|
||||
for j, part in enumerate(self.partitions):
|
||||
# part * 2 because addr and size contain both k and v
|
||||
end = len(addr_list) if j == len(
|
||||
self.partitions) - 1 else start + part * 2
|
||||
new_str = key[i].replace( # type: ignore[attr-defined]
|
||||
"@pp_rank:0", f"@pp_rank:{j}", 1)
|
||||
new_key.append(new_str)
|
||||
new_addr.append(addr_list[start:end])
|
||||
new_size.append(size_list[start:end])
|
||||
start = end
|
||||
return new_key, new_addr, new_size
|
||||
|
||||
|
||||
# Parameters related to the connector metadata
|
||||
@dataclass
|
||||
class LoadSpec:
|
||||
# Number of tokens cached in vLLM
|
||||
vllm_cached_tokens: int
|
||||
# Number of tokens that are cached in kvpool
|
||||
kvpool_cached_tokens: int
|
||||
# Whether the scheduler allow us to load the tokens
|
||||
can_load: bool
|
||||
|
||||
token_len: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestTracker:
|
||||
# Request id
|
||||
req_id: str
|
||||
|
||||
    # The number of tokens that have been scheduled so far
|
||||
token_len: int
|
||||
|
||||
    # The block ids that have been allocated so far
|
||||
# NOTE: allocated blocks could be more than the number of tokens
|
||||
# FIXME: need to check whether the block ids will be changed after
|
||||
# preemption
|
||||
allocated_block_ids: list[int]
|
||||
|
||||
    # The number of tokens that have been saved
|
||||
num_saved_tokens: int = 0
|
||||
|
||||
@staticmethod
|
||||
def from_new_request(
|
||||
new_request: "NewRequestData",
|
||||
num_tokens_to_compute: int,
|
||||
) -> "RequestTracker":
|
||||
"""Create the request tracker from a new request.
|
||||
|
||||
Args:
|
||||
new_request (NewRequestData): the new request data.
|
||||
num_tokens_to_compute (int): the number of tokens that will
|
||||
be 'computed', including the `num_computed_tokens` (vLLM's
|
||||
local cache hit) and new tokens that will be scheduled.
|
||||
|
||||
"""
|
||||
unfolded_block_ids = []
|
||||
|
||||
if not isinstance(new_request.block_ids[0], list):
|
||||
unfolded_block_ids = new_request.block_ids.copy()
|
||||
else:
|
||||
unfolded_block_ids = new_request.block_ids[0].copy()
|
||||
|
||||
return RequestTracker(
|
||||
req_id=new_request.req_id,
|
||||
token_len=num_tokens_to_compute,
|
||||
allocated_block_ids=unfolded_block_ids,
|
||||
num_saved_tokens=0,
|
||||
)
|
||||
|
||||
def update(
|
||||
self,
|
||||
new_block_ids: Union[tuple[list[int], ...], list[int]],
|
||||
) -> None:
|
||||
"""Update the request tracker when a running request is
|
||||
scheduled again
|
||||
"""
|
||||
|
||||
if len(new_block_ids) == 0:
|
||||
new_block_ids = []
|
||||
elif isinstance(new_block_ids, tuple):
|
||||
new_block_ids = new_block_ids[0]
|
||||
elif isinstance(new_block_ids, list):
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported new_block_ids type {type(new_block_ids)}")
|
||||
self.allocated_block_ids.extend(new_block_ids)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
# Request id
|
||||
req_id: str
|
||||
# Request tokens
|
||||
token_len_chunk: int
|
||||
|
||||
block_ids: list[int]
|
||||
|
||||
block_hashes: list[BlockHash]
|
||||
|
||||
can_save: Optional[bool] = None
|
||||
# load_spec
|
||||
load_spec: Optional[LoadSpec] = None
|
||||
|
||||
is_last_chunk: Optional[bool] = None
|
||||
|
||||
current_event: Optional[torch.npu.Event] = None
|
||||
|
||||
@staticmethod
|
||||
def from_request_tracker(
|
||||
tracker: RequestTracker,
|
||||
block_size: int,
|
||||
load_spec: Optional[LoadSpec] = None,
|
||||
skip_save: Optional[bool] = False,
|
||||
block_hashes: list[BlockHash] = [],
|
||||
is_last_chunk: Optional[bool] = None,
|
||||
discard_partial_chunks: bool = True,
|
||||
) -> Optional["ReqMeta"]:
|
||||
"""Create the request metadata from a request tracker.
|
||||
|
||||
Args:
|
||||
tracker (RequestTracker): the request tracker.
|
||||
block_size (int): the block size in vLLM.
|
||||
load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
|
||||
skip_save (bool): whether to skip the save operation.
|
||||
discard_partial_chunks (bool): whether to discard partial chunks.
|
||||
|
||||
Returns:
|
||||
the request metadata if we need to perform load/save
|
||||
operations, None otherwise.
|
||||
"""
|
||||
input_token_len = tracker.token_len
|
||||
|
||||
# For save operation: do not save if the following condition is met
|
||||
# 1. has already been saved before (num_saved_tokens > 0)
|
||||
# 2. number of unsaved tokens is not reached the chunk boundary
|
||||
chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) *
|
||||
block_size if discard_partial_chunks else 0)
|
||||
# Calculate number of tokens to save based on discard_partial_chunks
|
||||
# setting
|
||||
num_tokens_to_save = ((input_token_len // block_size * block_size)
|
||||
if discard_partial_chunks else input_token_len)
|
||||
|
||||
skip_save = skip_save or num_tokens_to_save < chunk_boundary
|
||||
if skip_save and load_spec is None:
|
||||
return None
|
||||
|
||||
# If we need to save, update the number of saved tokens
|
||||
if not skip_save:
|
||||
tracker.num_saved_tokens = num_tokens_to_save
|
||||
|
||||
        # For load operation: check whether the request is scheduled to load
|
||||
if load_spec is not None and load_spec.can_load:
|
||||
logger.debug(
|
||||
"Scheduled to load %d tokens for request %s",
|
||||
load_spec.kvpool_cached_tokens,
|
||||
tracker.req_id,
|
||||
)
|
||||
else:
|
||||
# Do not load if not in `can_load` state
|
||||
load_spec = None
|
||||
logger.debug(
|
||||
f"request:{tracker.req_id}, meta save spec:{not skip_save}, meta load spec:{load_spec}"
|
||||
)
|
||||
return ReqMeta(
|
||||
req_id=tracker.req_id,
|
||||
token_len_chunk=num_tokens_to_save,
|
||||
block_ids=tracker.allocated_block_ids,
|
||||
can_save=not skip_save,
|
||||
load_spec=load_spec,
|
||||
block_hashes=block_hashes,
|
||||
is_last_chunk=is_last_chunk,
|
||||
)
|
||||
|
||||
|
||||
class AscendConnectorMetadata(KVConnectorMetadata):
|
||||
|
||||
def __init__(self, unfinished_request_ids):
|
||||
self.requests = []
|
||||
self.unfinished_request_ids = unfinished_request_ids
|
||||
|
||||
def add_request(self, req_meta: ReqMeta) -> None:
|
||||
"""Add a request to the metadata.
|
||||
|
||||
Args:
|
||||
req_meta (ReqMeta): the request metadata.
|
||||
"""
|
||||
self.requests.append(req_meta)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LasyerMultiBlockReqMeta:
|
||||
req_id: str
|
||||
keys: List[LayerPoolKey]
|
||||
starts: List[int]
|
||||
ends: list[int]
|
||||
block_ids: list[int]
|
||||
layer_id: int
|
||||
is_last_chunk: Optional[bool] = True
|
||||
current_event: Optional[torch.npu.Event] = None
|
||||
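A small worked example of the chunking and key layout implemented above; block_size and the hash strings are made up, and the import path assumes the new package location introduced by this PR.

# Hedged sketch: chunking a 10-token prompt into per-block pool keys.
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
    ChunkedTokenDatabase, KeyMetadata)

meta = KeyMetadata(model_name="demo-model", head_or_tp_rank=0, pcp_rank=0,
                   dcp_rank=0, pp_rank=0)
db = ChunkedTokenDatabase(meta, block_size=4, use_mla=False, partitions=None)

# 10 tokens with three block hashes -> chunks [0, 4), [4, 8), [8, 10)
for start, end, key in db.process_tokens(10, ["a1", "b2", "c3"]):
    print(start, end, key.to_string())
# 0 4 demo-model@pcp0@dcp0@head_or_tp_rank:0@pp_rank:0@a1
# 4 8 demo-model@pcp0@dcp0@head_or_tp_rank:0@pp_rank:0@b2
# 8 10 demo-model@pcp0@dcp0@head_or_tp_rank:0@pp_rank:0@c3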
@@ -0,0 +1,361 @@
|
||||
import queue
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
|
||||
Backend
|
||||
|
||||
# isort: off
|
||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
|
||||
ChunkedTokenDatabase,
|
||||
LasyerMultiBlockReqMeta,
|
||||
ReqMeta,
|
||||
)
|
||||
# isort: on
|
||||
|
||||
|
||||
class KVTransferThread(threading.Thread):
|
||||
|
||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
||||
block_size: int, tp_rank: int, dcp_size: int,
|
||||
ready_event: threading.Event, name: str):
|
||||
super().__init__(daemon=True, name=name)
|
||||
self.m_store = m_store
|
||||
self.ready_event = ready_event
|
||||
self.block_size = block_size
|
||||
self.tp_rank = tp_rank
|
||||
self.dcp_size = dcp_size
|
||||
self.token_database = token_database
|
||||
self.done_task_lock = threading.Lock()
|
||||
self.request_queue: queue.Queue[Any] = queue.Queue()
|
||||
# TODO(jianzs): make this configurable
|
||||
self.executor = ThreadPoolExecutor(max_workers=32)
|
||||
self.finished_requests: set[str] = set()
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request: ReqMeta,
|
||||
) -> torch.Tensor:
|
||||
self.request_queue.put(request)
|
||||
|
||||
def get_and_clear_finished_requests(self) -> set[str]:
|
||||
"""
|
||||
Get and clear the requests that have been completed.
|
||||
Returns:
|
||||
A set of request IDs that have been completed.
|
||||
"""
|
||||
with self.done_task_lock:
|
||||
finished_requests = self.finished_requests.copy()
|
||||
self.finished_requests.clear()
|
||||
return finished_requests
|
||||
|
||||
def set_finished_request(self, req_id):
|
||||
with self.done_task_lock:
|
||||
self.finished_requests.add(req_id)
|
||||
|
||||
def run(self):
|
||||
"""Run the thread to handle KV cache transfer requests."""
|
||||
self.m_store.set_device()
|
||||
self.ready_event.set()
|
||||
while True:
|
||||
try:
|
||||
request_data = self.request_queue.get()
|
||||
if request_data is None:
|
||||
logger.warning("Received a None request!")
|
||||
self.request_queue.task_done()
|
||||
continue
|
||||
self._handle_request(request_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in KVCacheTransferThread: {e}")
|
||||
|
||||
def _handle_request(self, req_meta: Any):
|
||||
pass
|
||||
|
||||
def lookup(
|
||||
self,
|
||||
keys: list[str],
|
||||
) -> int:
|
||||
"""
|
||||
Checks the existence of KV cache of the tokens from the cache engine.
|
||||
:param tokens: the input tokens, with shape [seq_len]
|
||||
:return: An int indicating how many prefix tokens are cached.
|
||||
"""
|
||||
try:
|
||||
res = self.m_store.exists(keys) # type: ignore[assignment]
|
||||
for index, value in enumerate(res): # type: ignore[arg-type]
|
||||
if value != 1:
|
||||
return index
|
||||
            # all tokens were found; return the maximal end
|
||||
except Exception as e:
|
||||
logger.error(f"Remote connection failed in contains: {e}")
|
||||
return 0
|
||||
return len(keys)
|
||||
|
||||
|
||||
class KVCacheStoreSendingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
||||
block_size: int, tp_rank: int, dcp_size: int, put_step: int,
|
||||
kv_role: str, ready_event: threading.Event):
|
||||
super().__init__(m_store,
|
||||
token_database,
|
||||
block_size,
|
||||
tp_rank,
|
||||
dcp_size,
|
||||
ready_event,
|
||||
name="KVCacheSendingThread")
|
||||
self.put_step = put_step
|
||||
self.kv_role = kv_role
|
||||
self.stored_requests = defaultdict[str, int](int)
|
||||
|
||||
def add_stored_request(self, req_id: str):
|
||||
with self.done_task_lock:
|
||||
self.stored_requests[req_id] += 1
|
||||
|
||||
def delete_finished_stored_request(self, req_id: str):
|
||||
with self.done_task_lock:
|
||||
if req_id in self.stored_requests:
|
||||
del self.stored_requests[req_id]
|
||||
|
||||
def _handle_request(self, req_meta: ReqMeta):
|
||||
token_len = req_meta.token_len_chunk
|
||||
block_ids = req_meta.block_ids
|
||||
req_id = req_meta.req_id
|
||||
current_event = req_meta.current_event
|
||||
starts = []
|
||||
ends = []
|
||||
keys = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
token_len, req_meta.block_hashes):
|
||||
starts.append(start)
|
||||
ends.append(end)
|
||||
keys.append(key.to_string())
|
||||
|
||||
if not self.dcp_size > 1:
|
||||
starts = starts[self.tp_rank % self.put_step::self.put_step]
|
||||
ends = ends[self.tp_rank % self.put_step::self.put_step]
|
||||
keys = keys[self.tp_rank % self.put_step::self.put_step]
|
||||
|
||||
if not keys:
|
||||
with self.done_task_lock:
|
||||
self.stored_requests[req_id] -= 1
|
||||
return
|
||||
|
||||
skip_block_num = self.lookup(keys)
|
||||
|
||||
if skip_block_num == len(keys):
|
||||
with self.done_task_lock:
|
||||
self.stored_requests[req_id] -= 1
|
||||
return
|
||||
|
||||
starts = starts[skip_block_num:]
|
||||
ends = ends[skip_block_num:]
|
||||
keys = keys[skip_block_num:]
|
||||
|
||||
logger.info(
|
||||
"Storing KV cache for %d out of %d blocks "
|
||||
"(skip_block_num=%d) for request %s",
|
||||
len(keys),
|
||||
token_len // self.block_size,
|
||||
skip_block_num,
|
||||
req_id,
|
||||
)
|
||||
|
||||
if keys:
|
||||
"""
|
||||
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
|
||||
This issue will be fixed in CANN version 8.5.rc1.
|
||||
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
|
||||
to resolve this issue before the 8.5.RC1 release.
|
||||
"""
|
||||
addrs = []
|
||||
sizes = []
|
||||
for index, start in enumerate(starts):
|
||||
addr, size, _ = self.token_database.prepare_value(
|
||||
start, ends[index], block_ids)
|
||||
addrs.append(addr)
|
||||
sizes.append(size)
|
||||
|
||||
if self.kv_role == "kv_consumer":
|
||||
keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp(
|
||||
keys, addrs, sizes)
|
||||
|
||||
if current_event is not None:
|
||||
current_event.synchronize()
|
||||
self.m_store.put(keys, addrs, sizes)
|
||||
|
||||
with self.done_task_lock:
|
||||
self.stored_requests[req_id] -= 1
|
||||
self.request_queue.task_done()
|
||||
|
||||
|
||||
class KVCacheStoreRecvingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
||||
block_size: int, tp_rank: int, dcp_size: int,
|
||||
ready_event: threading.Event):
|
||||
super().__init__(m_store,
|
||||
token_database,
|
||||
block_size,
|
||||
tp_rank,
|
||||
dcp_size,
|
||||
ready_event,
|
||||
name="KVCacheStoreRecvingThread")
|
||||
|
||||
def _handle_request(self, req_meta: ReqMeta):
|
||||
token_len = req_meta.load_spec.token_len # type: ignore[union-attr]
|
||||
req_id = req_meta.req_id
|
||||
mask_num = (
|
||||
req_meta.load_spec.vllm_cached_tokens # type: ignore[union-attr]
|
||||
// self.block_size * self.block_size)
|
||||
addr_list = []
|
||||
size_list = []
|
||||
key_list = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
token_len, req_meta.block_hashes, mask_num):
|
||||
addr, size, _ = self.token_database.prepare_value(
|
||||
start, end, req_meta.block_ids)
|
||||
key_list.append(key.to_string())
|
||||
addr_list.append(addr)
|
||||
size_list.append(size)
|
||||
key_list_c = key_list[self.tp_rank %
|
||||
len(key_list):] + key_list[:self.tp_rank %
|
||||
len(key_list)]
|
||||
addr_list_c = addr_list[self.tp_rank %
|
||||
len(addr_list):] + addr_list[:self.tp_rank %
|
||||
len(addr_list)]
|
||||
size_list_c = size_list[self.tp_rank %
|
||||
len(size_list):] + size_list[:self.tp_rank %
|
||||
len(size_list)]
|
||||
self.m_store.get(key_list_c, addr_list_c, size_list_c)
|
||||
self.set_finished_request(req_id)
|
||||
self.request_queue.task_done()
|
||||
|
||||
|
||||
class KVCacheStoreLayerSendingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
||||
block_size: int, tp_rank: int, dcp_size: int, put_step: int,
|
||||
ready_event: threading.Event, num_layers: int):
|
||||
super().__init__(m_store,
|
||||
token_database,
|
||||
block_size,
|
||||
tp_rank,
|
||||
dcp_size,
|
||||
ready_event,
|
||||
name="KVCacheStoreLayerSendingThread")
|
||||
self.final_layer_id = num_layers - 1
|
||||
self.put_step = put_step
|
||||
|
||||
def add_request( # type: ignore[override]
|
||||
self, req_meta: ReqMeta) -> torch.Tensor:
|
||||
self.request_queue.put(req_meta)
|
||||
|
||||
def _handle_request( # type: ignore[override]
|
||||
self, req_meta: LasyerMultiBlockReqMeta):
|
||||
starts = req_meta.starts
|
||||
ends = req_meta.ends
|
||||
keys = req_meta.keys
|
||||
layer_id = req_meta.layer_id
|
||||
current_event = req_meta.current_event
|
||||
total_block = len(keys)
|
||||
is_last_chunk = req_meta.is_last_chunk
|
||||
if not self.dcp_size > 1:
|
||||
starts = starts[self.tp_rank % self.put_step::self.put_step]
|
||||
ends = ends[self.tp_rank % self.put_step::self.put_step]
|
||||
keys = keys[self.tp_rank % self.put_step::self.put_step]
|
||||
|
||||
if not keys:
|
||||
if is_last_chunk:
|
||||
self.set_finished_request(req_meta.req_id)
|
||||
return
|
||||
|
||||
key_list = []
|
||||
for key in keys:
|
||||
key_list.append(key.to_string())
|
||||
|
||||
skip_block_num = self.lookup(key_list)
|
||||
|
||||
if skip_block_num == len(key_list):
|
||||
if is_last_chunk and layer_id == self.final_layer_id:
|
||||
self.set_finished_request(req_meta.req_id)
|
||||
return
|
||||
|
||||
starts = starts[skip_block_num:]
|
||||
ends = ends[skip_block_num:]
|
||||
key_list = key_list[skip_block_num:]
|
||||
|
||||
addr_list = []
|
||||
size_list = []
|
||||
for index, key in enumerate(key_list):
|
||||
addr, size = self.token_database.prepare_value_layer(
|
||||
starts[index], ends[index], req_meta.block_ids, layer_id)
|
||||
addr_list.append(addr)
|
||||
size_list.append(size)
|
||||
|
||||
if current_event is not None:
|
||||
current_event.synchronize()
|
||||
self.m_store.put(key_list, addr_list, size_list)
|
||||
|
||||
if layer_id == self.final_layer_id and is_last_chunk:
|
||||
self.set_finished_request(req_meta.req_id)
|
||||
self.request_queue.task_done()
|
||||
|
||||
logger.info(
|
||||
"Storing KV cache for %d out of %d blocks "
|
||||
"(skip_block_num=%d) for request %s",
|
||||
len(keys),
|
||||
total_block,
|
||||
skip_block_num,
|
||||
req_meta.req_id,
|
||||
)
|
||||
|
||||
|
||||
class KVCacheStoreLayerRecvingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
||||
block_size: int, tp_rank: int, dcp_size: int,
|
||||
ready_event: threading.Event, get_event: threading.Event):
|
||||
super().__init__(m_store,
|
||||
token_database,
|
||||
block_size,
|
||||
tp_rank,
|
||||
dcp_size,
|
||||
ready_event,
|
||||
name="KVCacheStoreLayerRecvingThread")
|
||||
self.get_event = get_event
|
||||
|
||||
def add_request( # type: ignore[override]
|
||||
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
|
||||
self.request_queue.put(req_meta)
|
||||
|
||||
def _handle_request( # type: ignore[override]
|
||||
self, req_meta: LasyerMultiBlockReqMeta):
|
||||
addr_list = []
|
||||
size_list = []
|
||||
key_list = []
|
||||
for index, key in enumerate(req_meta.keys):
|
||||
addr, size = self.token_database.prepare_value_layer(
|
||||
req_meta.starts[index], req_meta.ends[index],
|
||||
req_meta.block_ids, req_meta.layer_id)
|
||||
key_list.append(key.to_string())
|
||||
addr_list.append(addr)
|
||||
size_list.append(size)
|
||||
key_list_c = key_list[self.tp_rank %
|
||||
len(key_list):] + key_list[:self.tp_rank %
|
||||
len(key_list)]
|
||||
addr_list_c = addr_list[self.tp_rank %
|
||||
len(addr_list):] + addr_list[:self.tp_rank %
|
||||
len(addr_list)]
|
||||
size_list_c = size_list[self.tp_rank %
|
||||
len(size_list):] + size_list[:self.tp_rank %
|
||||
len(size_list)]
|
||||
self.m_store.get(key_list_c, addr_list_c, size_list_c)
|
||||
|
||||
self.request_queue.task_done()
|
||||
self.get_event.set()
|
||||
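The strided slicing used by both sending threads above (keys[tp_rank % put_step::put_step]) is easy to misread. A tiny worked example, with made-up key names, shows how it appears to shard the blocks to be stored across the tensor-parallel ranks when decode context parallelism is not in use, so each block is written once per TP group rather than once per rank.

# Worked example of the put_step striding (values are illustrative).
keys = [f"block{i}" for i in range(8)]
put_step = 4
for tp_rank in range(4):
    print(tp_rank, keys[tp_rank % put_step::put_step])
# 0 ['block0', 'block4']
# 1 ['block1', 'block5']
# 2 ['block2', 'block6']
# 3 ['block3', 'block7']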
@@ -0,0 +1,272 @@
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import vllm.envs as envs
|
||||
import zmq
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.logger import logger
|
||||
from vllm.utils.network_utils import make_zmq_socket
|
||||
from vllm.utils.torch_utils import get_dtype_size
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
|
||||
|
||||
from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.cpu_kv_cache_manager import \
|
||||
CPUKVCacheManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLAConfig:
|
||||
nope_dim: int
|
||||
rope_dim: int
|
||||
|
||||
|
||||
def get_cpu_offload_connector(vllm_config: VllmConfig) -> KVTransferConfig:
|
||||
if vllm_config.kv_transfer_config is not None:
|
||||
kv_transfer_config = vllm_config.kv_transfer_config
|
||||
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
|
||||
return kv_transfer_config
|
||||
elif kv_transfer_config.kv_connector == "MultiConnector":
|
||||
ktcs = kv_transfer_config.kv_connector_extra_config.get(
|
||||
"connectors")
|
||||
for ktc in ktcs:
|
||||
kv_transfer_config = KVTransferConfig(**ktc)
|
||||
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
|
||||
return kv_transfer_config
|
||||
return None
|
||||
|
||||
|
||||
class MetadataServer:
|
||||
METADATA_SERVER_ADDRESS = f"ipc://{envs.VLLM_RPC_BASE_PATH}/metadata.ipc"
|
||||
DEFAULT_CPU_SWAP_SPACE_GB = 800
|
||||
|
||||
class ZMQRPCClient:
|
||||
|
||||
def __init__(self, identity=f"worker-{os.getpid()}"):
|
||||
logger.info(f"metadata client for worker {identity} started")
|
||||
self.ctx = zmq.Context() # type: ignore
|
||||
self.socket = make_zmq_socket(
|
||||
self.ctx,
|
||||
MetadataServer.METADATA_SERVER_ADDRESS,
|
||||
zmq.DEALER, # type: ignore
|
||||
bind=False,
|
||||
identity=identity.encode(),
|
||||
linger=0)
|
||||
|
||||
def call(self, func_name: str, *args, **kwargs) -> Any:
|
||||
request = (func_name, args, kwargs)
|
||||
self.socket.send(b"", zmq.SNDMORE) # type: ignore
|
||||
self.socket.send(pickle.dumps(request))
|
||||
_ = self.socket.recv()
|
||||
response = pickle.loads(self.socket.recv())
|
||||
result, error = response
|
||||
if error:
|
||||
logger.exception(f"call metadata sever error: {error}")
|
||||
raise error
|
||||
if func_name == "init_cpu_kv_caches":
|
||||
(memory_dict, layer_size, layer_dtype, mla_config) = result
|
||||
# shared_memory_dict is recorded in self to close
|
||||
self.shared_memory_dict = memory_dict
|
||||
result = {}
|
||||
for key, shm in memory_dict.items():
|
||||
tensor = torch.frombuffer(
|
||||
shm.buf, dtype=layer_dtype).reshape(layer_size)
|
||||
if mla_config is not None:
|
||||
tensor = tensor.split(
|
||||
[mla_config.nope_dim, mla_config.rope_dim], dim=-1)
|
||||
result[key] = tensor
|
||||
return result
|
||||
|
||||
def __del__(self):
|
||||
# will be finalized by outer process
|
||||
self.socket.close()
|
||||
self.ctx.term()
|
||||
if hasattr(self, 'shared_memory_dict'):
|
||||
for shm in self.shared_memory_dict.values():
|
||||
shm.close()
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
self.world_size = vllm_config.parallel_config.world_size
|
||||
self.pipeline_parallel_size = vllm_config.parallel_config.pipeline_parallel_size
|
||||
kv_transfer_config = get_cpu_offload_connector(vllm_config)
|
||||
assert kv_transfer_config is not None
|
||||
available_memory_gb = kv_transfer_config.get_from_extra_config(
|
||||
"cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB)
|
||||
self.available_memory = available_memory_gb * 1024 * 1024 * 1024
|
||||
logger.info(f"cpu swap space: {self.available_memory} bytes")
|
||||
self.ctx = zmq.Context() # type: ignore
|
||||
self.socket = make_zmq_socket(
|
||||
self.ctx,
|
||||
MetadataServer.METADATA_SERVER_ADDRESS,
|
||||
zmq.ROUTER, # type: ignore
|
||||
bind=True,
|
||||
linger=0)
|
||||
self.functions: dict[str, Callable] = {
|
||||
"init_cpu_kv_caches": self.init_cpu_kv_caches,
|
||||
"post_init": self.post_init,
|
||||
"ready": self.ready,
|
||||
}
|
||||
self.shared_memory = {} # type: ignore
|
||||
self.num_cpu_blocks = -1
|
||||
|
||||
@staticmethod
|
||||
def _safe_create_shared_memory(name: str, size: int) -> SharedMemory:
|
||||
try:
|
||||
existing_shm = SharedMemory(name=name, create=False)
|
||||
existing_shm.close()
|
||||
existing_shm.unlink()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
return SharedMemory(name=name, create=True, size=size)
|
||||
|
||||
def ready(self):
|
||||
return True
|
||||
|
||||
def init_cpu_kv_caches(
|
||||
self,
|
||||
pp_rank: int,
|
||||
tp_rank: int,
|
||||
kv_cache_specs: dict[str, AttentionSpec],
|
||||
mla_config: MLAConfig,
|
||||
) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype,
|
||||
MLAConfig]:
|
||||
logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}")
|
||||
# follow the assumption that each layer has the same spec
|
||||
layer = next(iter(kv_cache_specs.values()))
|
||||
assert all([
|
||||
layer.page_size_bytes == any.page_size_bytes
|
||||
for any in kv_cache_specs.values()
|
||||
])
|
||||
use_mla = isinstance(layer, MLAAttentionSpec)
|
||||
# mla shares the same kv cache among different tp
|
||||
if use_mla:
|
||||
tp_rank = 0
|
||||
if (pp_rank, tp_rank) in self.shared_memory:
|
||||
return self.shared_memory[(pp_rank, tp_rank)]
|
||||
available_memory = self.available_memory
|
||||
shared_memory_dict = {}
|
||||
if use_mla:
|
||||
available_memory //= self.pipeline_parallel_size
|
||||
available_memory //= len(kv_cache_specs)
|
||||
num_blocks = available_memory // layer.page_size_bytes
|
||||
layer_size = (num_blocks, layer.block_size, layer.num_kv_heads,
|
||||
layer.head_size) # type: ignore
|
||||
else:
|
||||
available_memory //= self.world_size
|
||||
available_memory //= len(kv_cache_specs)
|
||||
num_blocks = available_memory // layer.page_size_bytes
|
||||
layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads,
|
||||
layer.head_size) # type: ignore
|
||||
nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype)
|
||||
for layer_name in kv_cache_specs.keys():
|
||||
            # only this format can be shared across ZeroMQ + pickle
|
||||
shared_memory_dict[
|
||||
layer_name] = MetadataServer._safe_create_shared_memory(
|
||||
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
|
||||
if use_mla:
|
||||
assert mla_config is not None
|
||||
assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
|
||||
self.shared_memory[(pp_rank,
|
||||
tp_rank)] = (shared_memory_dict, layer_size,
|
||||
layer.dtype, mla_config)
|
||||
else:
|
||||
self.shared_memory[(pp_rank,
|
||||
tp_rank)] = (shared_memory_dict, layer_size,
|
||||
layer.dtype, None)
|
||||
if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks:
|
||||
self.num_cpu_blocks = num_blocks
|
||||
self.layer = layer
|
||||
return self.shared_memory[(pp_rank, tp_rank)]
|
||||
|
||||
def post_init(self):
|
||||
        # different processes under data parallel may call this multiple times
|
||||
if hasattr(self, 'cpu_block_manager'):
|
||||
return
|
||||
# do shared_memory() at least once
|
||||
logger.info(f"assign cpu num blocks: {self.num_cpu_blocks}")
|
||||
assert self.num_cpu_blocks >= 0
|
||||
self.cpu_block_manager = CPUKVCacheManager(self.layer,
|
||||
self.num_cpu_blocks)
|
||||
self.functions.update({
|
||||
"get_matched_num_and_touch":
|
||||
self.cpu_block_manager.get_matched_num_and_touch,
|
||||
"allocate_slots":
|
||||
self.cpu_block_manager.allocate_slots,
|
||||
"record_request_cache_and_free_slots":
|
||||
self.cpu_block_manager.record_request_cache_and_free_slots,
|
||||
"cache_and_free_slots":
|
||||
self.cpu_block_manager.cache_and_free_slots,
|
||||
})
|
||||
|
||||
def serve_step(self):
|
||||
client_id = self.socket.recv()
|
||||
_ = self.socket.recv()
|
||||
raw_msg = self.socket.recv()
|
||||
try:
|
||||
func_name, args, kwargs = pickle.loads(raw_msg)
|
||||
except Exception as e:
|
||||
response = (None, Exception(f"Invalid request: {str(e)}"))
|
||||
else:
|
||||
if func_name in self.functions:
|
||||
try:
|
||||
result = self.functions[func_name](*args, **kwargs)
|
||||
response = (result, None) # type: ignore
|
||||
except Exception as e:
|
||||
logger.exception(f"metadata execute error: {e}")
|
||||
response = (None, e) # type: ignore
|
||||
else:
|
||||
response = (None, NameError(f"Function {func_name} not found"))
|
||||
self.socket.send(client_id, zmq.SNDMORE) # type: ignore
|
||||
self.socket.send(b"", zmq.SNDMORE) # type: ignore
|
||||
self.socket.send(pickle.dumps(response))
|
||||
|
||||
def shutdown(self):
|
||||
self.socket.close()
|
||||
self.ctx.term()
|
||||
socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace(
|
||||
"ipc://", "")
|
||||
if os.path.exists(socket_path):
|
||||
os.remove(socket_path)
|
||||
for cached in self.shared_memory.values():
|
||||
for shm in cached[0].values():
|
||||
shm.close()
|
||||
shm.unlink()
|
||||
|
||||
|
||||
class MetadataServerProc:
|
||||
|
||||
@staticmethod
|
||||
def run_metadata_server(vllm_config: VllmConfig):
|
||||
if (not vllm_config.cache_config.enable_prefix_caching
|
||||
or get_cpu_offload_connector(vllm_config) is None):
|
||||
return
|
||||
|
||||
shutdown_requested = False
|
||||
|
||||
def _signal_handler(signum, frame):
|
||||
nonlocal shutdown_requested
|
||||
if not shutdown_requested:
|
||||
shutdown_requested = True
|
||||
raise SystemExit()
|
||||
|
||||
# Either SIGTERM or SIGINT will terminate the worker
|
||||
# signal.signal(signal.SIGTERM, _signal_handler)
|
||||
# signal.signal(signal.SIGINT, _signal_handler)
|
||||
metadata_server: Optional[MetadataServer] = None
|
||||
try:
|
||||
metadata_server = MetadataServer(vllm_config)
|
||||
logger.info("Metadata server started.")
|
||||
while True:
|
||||
metadata_server.serve_step()
|
||||
except SystemExit:
|
||||
logger.info("Metadata server exiting.")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"Metadata server error: {e}.")
|
||||
raise e
|
||||
finally:
|
||||
if metadata_server is not None:
|
||||
metadata_server.shutdown()
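# ---------------------------------------------------------------------------
# Illustrative client-side sketch of the pickle-over-ZMQ RPC handled by
# MetadataServer.serve_step() above. A REQ socket pairs with the server's
# ROUTER-style recv pattern (identity frame, empty delimiter, payload). The
# socket path and the function name used below are assumptions for
# illustration; the production client is the ZMQRPCClient defined elsewhere
# in this module.
# ---------------------------------------------------------------------------
def _example_metadata_rpc_call(socket_path: str, func_name: str, *args,
                               **kwargs):
    import pickle

    import zmq

    ctx = zmq.Context()
    sock = ctx.socket(zmq.REQ)  # REQ adds/strips the empty delimiter frame
    sock.connect(socket_path)   # e.g. "ipc:///tmp/metadata_server" (assumed)
    try:
        # The server unpickles exactly this (func_name, args, kwargs) tuple.
        sock.send(pickle.dumps((func_name, args, kwargs)))
        result, error = pickle.loads(sock.recv())
    finally:
        sock.close(linger=0)
        ctx.term()
    if error is not None:
        raise error
    return result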
|
||||
@@ -0,0 +1,331 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
import vllm.envs as envs
|
||||
import zmq
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
|
||||
KVConnectorMetadata
|
||||
from vllm.logger import logger
|
||||
from vllm.utils.network_utils import make_zmq_socket
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.serial_utils import MsgpackEncoder
|
||||
|
||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
|
||||
AscendConnectorMetadata, LoadSpec, ReqMeta, RequestTracker)
|
||||
|
||||
|
||||
class KVPoolScheduler:
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", use_layerwise):
|
||||
self.use_layerwise = use_layerwise
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"consumer_is_to_load", False)
|
||||
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"consumer_is_to_put", False)
|
||||
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"load_async", False)
|
||||
self.client = LookupKeyClient(vllm_config)
|
||||
# request_id -> (vllm cached tokens, kvpool cached tokens)
|
||||
self.load_specs: dict[str, LoadSpec] = {}
|
||||
self.pcp_size = getattr(vllm_config.parallel_config,
|
||||
"prefill_context_parallel_size", 1)
|
||||
self.dcp_size = getattr(vllm_config.parallel_config,
|
||||
"decode_context_parallel_size", 1)
|
||||
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
if self.pcp_size > 1:
|
||||
self._block_size *= self.pcp_size
|
||||
if self.dcp_size > 1:
|
||||
self._block_size *= self.dcp_size
|
||||
# request_id -> RequestTracker (tracks the request's tokens and blocks)
|
||||
self._request_trackers: dict[str, RequestTracker] = {}
|
||||
# Whether to discard partial chunks
|
||||
self._discard_partial_chunks = (
|
||||
vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"discard_partial_chunks", True))
|
||||
self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {}
|
||||
self._unfinished_request_ids: set[str] = set()
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
Check for external KV cache hit.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
"""
|
||||
if self.kv_role == "kv_consumer" and not self.consumer_is_to_load:
|
||||
return 0, False
|
||||
|
||||
if self._discard_partial_chunks:
|
||||
token_len = len(request.prompt_token_ids
|
||||
) // self._block_size * self._block_size
|
||||
else:
|
||||
token_len = len(request.prompt_token_ids)
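# Example (illustrative numbers): with a block size of 128 and a 300-token
# prompt, discarding partial chunks looks up 300 // 128 * 128 = 256 tokens,
# whereas keeping partial chunks looks up all 300 tokens.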
|
||||
|
||||
num_external_hit_tokens = self.client.lookup(token_len,
|
||||
request.block_hashes)
|
||||
|
||||
if num_external_hit_tokens == request.num_tokens:
|
||||
# When the whole prompt hits the external cache, leave the last token to
# be recomputed so the model still produces output logits for it.
num_external_hit_tokens -= 1
|
||||
|
||||
if num_external_hit_tokens < num_computed_tokens:
|
||||
need_to_allocate = 0
|
||||
else:
|
||||
need_to_allocate = num_external_hit_tokens - num_computed_tokens
|
||||
|
||||
logger.info(
|
||||
"Reqid: %s, Total tokens %d, kvpool hit tokens: %d, need to load: %d",
|
||||
request.request_id,
|
||||
request.num_tokens,
|
||||
num_external_hit_tokens,
|
||||
need_to_allocate,
|
||||
)
|
||||
|
||||
if need_to_allocate <= 0:
|
||||
return 0, False
|
||||
|
||||
self.load_specs[request.request_id] = LoadSpec(
|
||||
vllm_cached_tokens=num_computed_tokens,
|
||||
kvpool_cached_tokens=num_external_hit_tokens,
|
||||
can_load=False,
|
||||
)
|
||||
|
||||
return need_to_allocate, self.load_async and not self.use_layerwise
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
"""
|
||||
Update KVConnector state after temporary buffer alloc.
|
||||
|
||||
Mark the request as loadable if the KVCacheManager has
allocated blocks for it.
|
||||
"""
|
||||
local_block_ids = []
|
||||
if num_external_tokens > 0:
|
||||
local_block_ids = blocks.get_block_ids()[0]
|
||||
|
||||
self._unfinished_requests[request.request_id] = (request,
|
||||
local_block_ids)
|
||||
self._unfinished_request_ids.add(request.request_id)
|
||||
if request.request_id not in self.load_specs:
|
||||
# No KV tokens from external KV cache, return
|
||||
return
|
||||
|
||||
if num_external_tokens == 0:
|
||||
# No need to load anything
|
||||
self.load_specs[request.request_id].can_load = False
|
||||
return
|
||||
|
||||
assert (
|
||||
num_external_tokens > 0 and num_external_tokens
|
||||
== self.load_specs[request.request_id].kvpool_cached_tokens -
|
||||
self.load_specs[request.request_id].vllm_cached_tokens
|
||||
), (f"Mismatch in number of tokens: {num_external_tokens} vs "
|
||||
f"{self.load_specs[request.request_id].kvpool_cached_tokens} - "
|
||||
f"{self.load_specs[request.request_id].vllm_cached_tokens}"
|
||||
f" for request {request.request_id}")
|
||||
|
||||
self.load_specs[request.request_id].can_load = True
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
"""Attach the connector metadata to the request object.
|
||||
|
||||
This function should NOT modify other fields in the scheduler_output
|
||||
except the `kv_connector_metadata` field.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
|
||||
force_skip_save = (self.kv_role == "kv_consumer"
|
||||
and not self.consumer_is_to_put)
|
||||
|
||||
for finished_req_id in scheduler_output.finished_req_ids:
|
||||
self._request_trackers.pop(finished_req_id, None)
|
||||
self._unfinished_requests.pop(finished_req_id, None)
|
||||
self._unfinished_request_ids.discard(finished_req_id)
|
||||
|
||||
meta = AscendConnectorMetadata(self._unfinished_request_ids)
|
||||
|
||||
for request in scheduler_output.scheduled_new_reqs:
|
||||
# Right now, we only load KV for new requests
|
||||
load_spec = self.load_specs.pop(request.req_id, None)
|
||||
num_tokens_to_compute = (
|
||||
request.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[request.req_id])
|
||||
request_tracker = RequestTracker.from_new_request(
|
||||
request, num_tokens_to_compute)
|
||||
self._request_trackers[request.req_id] = request_tracker
|
||||
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
|
||||
self._block_size * self._block_size)
|
||||
if self._discard_partial_chunks else len(
|
||||
request.prompt_token_ids))
|
||||
request_tuple = self._unfinished_requests.get(request.req_id)
|
||||
request_real = request_tuple[0] # type: ignore[index]
|
||||
req_meta = ReqMeta.from_request_tracker(
|
||||
request_tracker,
|
||||
self._block_size,
|
||||
load_spec=load_spec,
|
||||
skip_save=force_skip_save,
|
||||
block_hashes=request_real.block_hashes,
|
||||
is_last_chunk=request_tracker.token_len
|
||||
>= last_chunk_tokens_num,
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
)
|
||||
if req_meta is not None:
|
||||
meta.add_request(req_meta)
|
||||
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
if not force_skip_save:
|
||||
for i, req_id in enumerate(cached_reqs.req_ids):
|
||||
request_tracker = self._request_trackers[req_id]
|
||||
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
req_tuple = self._unfinished_requests.get(req_id)
|
||||
if req_tuple:
|
||||
request = req_tuple[0]
|
||||
num_current_tokens = request_tracker.token_len
|
||||
new_token_ids = request.all_token_ids[
|
||||
num_current_tokens:num_current_tokens + num_new_tokens]
|
||||
request_tracker.token_len += len(new_token_ids)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Request {req_id} is not in _unfinished_requests, "
|
||||
f"but it is scheduled to be cached")
|
||||
new_block_ids = cached_reqs.new_block_ids[i]
|
||||
if not new_block_ids:
|
||||
continue
|
||||
request_tracker.update(new_block_ids)
|
||||
|
||||
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
|
||||
self._block_size * self._block_size)
|
||||
if self._discard_partial_chunks else
|
||||
len(request.prompt_token_ids))
|
||||
req_meta = ReqMeta.from_request_tracker(
|
||||
request_tracker,
|
||||
self._block_size,
|
||||
load_spec=None,
|
||||
skip_save=force_skip_save,
|
||||
block_hashes=request.block_hashes,
|
||||
is_last_chunk=request_tracker.token_len
|
||||
>= last_chunk_tokens_num,
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
)
|
||||
if req_meta is not None:
|
||||
meta.add_request(req_meta)
|
||||
|
||||
request_ids = [
|
||||
req.req_id for req in scheduler_output.scheduled_new_reqs
|
||||
]
|
||||
for request_id, (request,
|
||||
block_ids) in self._unfinished_requests.items():
|
||||
if request_id not in request_ids and request_id not in cached_reqs.req_ids:
|
||||
load_spec = self.load_specs.pop(request_id, None)
|
||||
if not load_spec:
|
||||
continue
|
||||
num_tokens_to_compute = load_spec.kvpool_cached_tokens
|
||||
if (num_tokens_to_compute % self._block_size
|
||||
!= 0) and (num_tokens_to_compute
|
||||
== len(request.prompt_token_ids) - 1):
|
||||
num_tokens_to_compute = num_tokens_to_compute + 1
|
||||
request_tracker = RequestTracker(
|
||||
req_id=request_id,
|
||||
token_len=num_tokens_to_compute,
|
||||
allocated_block_ids=block_ids,
|
||||
num_saved_tokens=0,
|
||||
)
|
||||
|
||||
self._request_trackers[request_id] = request_tracker
|
||||
|
||||
req_meta = ReqMeta.from_request_tracker(
|
||||
request_tracker,
|
||||
self._block_size,
|
||||
load_spec=load_spec,
|
||||
skip_save=None,
|
||||
block_hashes=request.block_hashes,
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
)
|
||||
if req_meta is not None:
|
||||
meta.add_request(req_meta)
|
||||
return meta
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
"""
|
||||
Once a request is finished, determine whether request blocks
|
||||
should be freed now or will be sent asynchronously and freed later.
|
||||
"""
|
||||
if self.kv_role == "kv_consumer" and not self.consumer_is_to_put:
|
||||
return False, None
|
||||
tracker = self._request_trackers.get(request.request_id)
|
||||
if tracker is not None and tracker.num_saved_tokens <= 0:
|
||||
return False, None
|
||||
delay_free_blocks = len(block_ids) > 0
|
||||
if delay_free_blocks:
|
||||
logger.info("Delaying free of %d blocks for request %s",
|
||||
len(block_ids), request.request_id)
|
||||
return delay_free_blocks, None
|
||||
|
||||
|
||||
class LookupKeyClient:
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig"):
|
||||
self.encoder = MsgpackEncoder()
|
||||
self.ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
socket_path = get_zmq_rpc_path_lookup(vllm_config)
|
||||
self.socket = make_zmq_socket(
|
||||
self.ctx,
|
||||
socket_path,
|
||||
zmq.REQ, # type: ignore[attr-defined]
|
||||
bind=False,
|
||||
)
|
||||
|
||||
def lookup(self, token_len: int, block_hashes: list[BlockHash]) -> int:
|
||||
hash_strs = [h.hex() for h in block_hashes]
|
||||
hash_frames = self.encoder.encode(hash_strs)
|
||||
token_len_bytes = token_len.to_bytes(4, byteorder="big")
|
||||
all_frames = [token_len_bytes] + list(hash_frames)
|
||||
self.socket.send_multipart(all_frames, copy=False)
|
||||
resp = self.socket.recv()
|
||||
result = int.from_bytes(resp, "big")
|
||||
return result
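# Wire format sketch (for reference): the request is a multipart message of
# a 4-byte big-endian token_len frame followed by the msgpack-encoded list
# of hex block-hash strings, and the reply is a single frame holding the
# number of matched prefix tokens as a big-endian integer. A hypothetical
# server-side decode could look like:
#     token_len = int.from_bytes(frames[0], "big")
#     hash_strs = msgpack_decoder.decode(frames[1:])   # decoder name assumed
#     socket.send(matched_tokens.to_bytes(8, "big"))   # reply width assumed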
|
||||
|
||||
def close(self):
|
||||
self.socket.close(linger=0)
|
||||
|
||||
|
||||
def get_zmq_rpc_path_lookup(vllm_config: "VllmConfig") -> str:
|
||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
base_url = envs.VLLM_RPC_BASE_PATH
|
||||
# Default to 0 if not configured
|
||||
rpc_port = 0
|
||||
if vllm_config is not None:
|
||||
extra_config = vllm_config.kv_transfer_config.kv_connector_extra_config
|
||||
if "lookup_rpc_port" in extra_config:
|
||||
rpc_port = extra_config["lookup_rpc_port"]
|
||||
elif "mooncake_rpc_port" in extra_config:
|
||||
rpc_port = extra_config["mooncake_rpc_port"]
|
||||
logger.warning(
|
||||
"It is recommended to use the lookup_rpc_port, as the mooncake_rpc_port will be removed in the future."
|
||||
)
|
||||
logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port)
|
||||
return f"ipc://{base_url}/lookup_rpc_port_{rpc_port}_dp_rank{dp_rank}"
|
||||
@@ -0,0 +1,621 @@
|
||||
import math
|
||||
import threading
|
||||
from typing import Dict, Generator, Optional, Type
|
||||
|
||||
import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (get_decode_context_model_parallel_rank,
|
||||
get_decode_context_model_parallel_world_size,
|
||||
get_pcp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.logger import logger
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
|
||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
|
||||
Backend
|
||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.memcache_backend import \
|
||||
MemcacheBackend
|
||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.mooncake_backend import \
|
||||
MooncakeBackend
|
||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
|
||||
AscendConnectorMetadata, ChunkedTokenDatabase, KeyMetadata,
|
||||
LasyerMultiBlockReqMeta, ReqMeta)
|
||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
|
||||
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
|
||||
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
|
||||
|
||||
backend_map: Dict[str, Type[Backend]] = {
|
||||
"mooncake": MooncakeBackend,
|
||||
"memcache": MemcacheBackend,
|
||||
}
|
||||
|
||||
|
||||
class KVPoolWorker:
|
||||
# The main class for the cache engine.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
use_layerwise: bool,
|
||||
):
|
||||
model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.dp_rank = parallel_config.data_parallel_rank
|
||||
self.use_mla = False
|
||||
if (hasattr(model_config, "use_mla")
|
||||
and isinstance(model_config.use_mla, bool)
|
||||
and model_config.use_mla):
|
||||
self.use_mla = True
|
||||
self.use_layerwise = use_layerwise
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.pp_size = parallel_config.pipeline_parallel_size
|
||||
self.pp_rank = (parallel_config.rank // self.tp_size) % self.pp_size
|
||||
|
||||
self.pcp_size = get_pcp_group().world_size
|
||||
self.pcp_rank = get_pcp_group(
|
||||
).rank_in_group if self.pcp_size > 1 else 0
|
||||
self.dcp_size = get_decode_context_model_parallel_world_size()
|
||||
self.dcp_rank = get_decode_context_model_parallel_rank(
|
||||
) if self.dcp_size > 1 else 0
|
||||
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"load_async", False)
|
||||
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"consumer_is_to_put", False)
|
||||
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"backend", "mooncake")
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
|
||||
if self.pcp_size > 1:
|
||||
self.block_size *= self.pcp_size
|
||||
if self.dcp_size > 1:
|
||||
self.block_size *= self.dcp_size
|
||||
self.current_layer = 0
|
||||
self.num_layers = model_config.get_num_layers(parallel_config)
|
||||
|
||||
if self.use_mla:
|
||||
self.num_kv_head = 1
|
||||
else:
|
||||
self.num_kv_head = model_config.get_total_num_kv_heads()
|
||||
|
||||
if self.num_kv_head < self.tp_size:
|
||||
self.put_step = self.tp_size // self.num_kv_head
|
||||
self.head_or_tp_rank = self.tp_rank // self.put_step
|
||||
else:
|
||||
self.head_or_tp_rank = self.tp_rank
|
||||
self.put_step = 1
|
||||
|
||||
self.metadata = KeyMetadata(
|
||||
model_config.model.rstrip('/').split('/')[-1],
|
||||
self.head_or_tp_rank,
|
||||
self.pcp_rank,
|
||||
self.dcp_rank,
|
||||
self.pp_rank,
|
||||
)
|
||||
|
||||
partitions = None
|
||||
if self.kv_role == "kv_consumer" and self.consumer_is_to_put:
|
||||
num_hidden_layers = model_config.hf_text_config.num_hidden_layers
|
||||
partition_list_str = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"prefill_pp_layer_partition", None)
|
||||
prefill_pp_size = int(
|
||||
vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"prefill_pp_size", 1))
|
||||
|
||||
if partition_list_str is not None:
|
||||
try:
|
||||
partitions = [
|
||||
int(layer) for layer in partition_list_str.split(",")
|
||||
]
|
||||
except ValueError as err:
|
||||
raise ValueError("Invalid partition string: {}".format(
|
||||
partition_list_str)) from err
|
||||
if len(partitions) != prefill_pp_size:
|
||||
raise ValueError(
|
||||
f"{len(partitions)=} does not match {prefill_pp_size=}."
|
||||
)
|
||||
if sum(partitions) != num_hidden_layers:
|
||||
raise ValueError(
|
||||
f"{sum(partitions)=} does not match {num_hidden_layers=}."
|
||||
)
|
||||
else:
|
||||
layers_per_partition = num_hidden_layers // prefill_pp_size
|
||||
partitions = [
|
||||
layers_per_partition for _ in range(prefill_pp_size)
|
||||
]
|
||||
|
||||
if remaining_layers := num_hidden_layers % prefill_pp_size:
|
||||
for i in range(2, remaining_layers + 2):
|
||||
partitions[-i] += 1
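# Example (illustrative numbers): with num_hidden_layers=61 and
# prefill_pp_size=4, layers_per_partition is 15 and one layer remains, so the
# loop above yields partitions == [15, 15, 16, 15] (the remainder is spread
# starting from the second-to-last stage).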
|
||||
|
||||
self.token_database = ChunkedTokenDatabase(self.metadata,
|
||||
self.block_size,
|
||||
self.use_mla, partitions)
|
||||
|
||||
real_backend = backend_map.get(self.backend.lower())
|
||||
|
||||
# TODO: this special case will be removed later
|
||||
if self.backend == "mooncake":
|
||||
self.head_or_tp_rank = self.tp_rank
|
||||
self.put_step = 1
|
||||
|
||||
self.m_store = real_backend( # type: ignore[misc]
|
||||
parallel_config)
|
||||
|
||||
self.kv_send_thread: Optional[KVTransferThread] = None
|
||||
self.kv_recv_thread: Optional[KVTransferThread] = None
|
||||
|
||||
self.finished_store_req: set[str] = set()
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
|
||||
first_kv_cache = first_kv_cache_tuple[0]
|
||||
|
||||
# TODO(tms): Find a more robust way to detect and handle MLA
|
||||
if self.use_mla:
|
||||
# MLA case.[num_block, block_size, 1, hidden_dim]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 3 # [block_size, latent_dim]
|
||||
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
|
||||
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
|
||||
self.block_len = [
|
||||
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
|
||||
first_kv_cache[1].element_size() * math.prod(block_shape_pe)
|
||||
]
|
||||
logger.info(
|
||||
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
|
||||
self.num_blocks, block_shape_norm, block_shape_pe)
|
||||
else:
|
||||
# [num_block, block_size, num_head, hidden_dim]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
kv_elem_size = first_kv_cache.element_size()
|
||||
block_rank = 3 # [block_size, kv_heads, head_dim]
|
||||
block_shape = first_kv_cache.shape[-block_rank:]
|
||||
self.block_len = [kv_elem_size * math.prod(block_shape)]
|
||||
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
|
||||
block_shape)
|
||||
|
||||
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
|
||||
self.use_mla, first_kv_cache.shape)
|
||||
|
||||
self.kv_caches = kv_caches
|
||||
self.kv_caches_base_addr = []
|
||||
ptrs = []
|
||||
lengths = []
|
||||
for cache_or_caches in kv_caches.values():
|
||||
# Normalize to always be a list of caches
|
||||
if self.use_mla:
|
||||
for i, cache in enumerate(cache_or_caches, 0):
|
||||
base_addr = cache.data_ptr()
|
||||
self.kv_caches_base_addr.append(base_addr)
|
||||
region_len = self.num_blocks * self.block_len[i % 2]
|
||||
ptrs.append(base_addr)
|
||||
lengths.append(region_len)
|
||||
else:
|
||||
cache_list = cache_or_caches
|
||||
for cache in cache_list:
|
||||
base_addr = cache.data_ptr()
|
||||
self.kv_caches_base_addr.append(base_addr)
|
||||
region_len = self.num_blocks * self.block_len[0]
|
||||
ptrs.append(base_addr)
|
||||
lengths.append(region_len)
|
||||
self.m_store.register_buffer(ptrs, lengths)
|
||||
self.token_database.set_kv_caches_base_addr(self.kv_caches_base_addr)
|
||||
self.token_database.set_block_len(self.block_len)
|
||||
|
||||
if self.use_layerwise:
|
||||
self.get_event = threading.Event()
|
||||
if self.kv_role in ['kv_producer', 'kv_both']:
|
||||
ready_event_sending = threading.Event()
|
||||
self.kv_send_thread = KVCacheStoreLayerSendingThread(
|
||||
self.m_store, self.token_database, self.block_size,
|
||||
self.tp_rank, self.dcp_size, self.put_step,
|
||||
ready_event_sending, self.num_layers)
|
||||
self.kv_send_thread.start()
|
||||
ready_event = threading.Event()
|
||||
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
|
||||
self.m_store, self.token_database, self.block_size,
|
||||
self.tp_rank, self.dcp_size, ready_event, self.get_event)
|
||||
self.kv_recv_thread.start()
|
||||
ready_event.wait()
|
||||
else:
|
||||
if self.kv_role in ['kv_producer', 'kv_both'
|
||||
] or self.consumer_is_to_put:
|
||||
ready_event_sending = threading.Event()
|
||||
self.kv_send_thread = KVCacheStoreSendingThread(
|
||||
self.m_store, self.token_database, self.block_size,
|
||||
self.tp_rank, self.dcp_size, self.put_step, self.kv_role,
|
||||
ready_event_sending)
|
||||
self.kv_send_thread.start()
|
||||
if self.load_async:
|
||||
ready_event = threading.Event()
|
||||
self.kv_recv_thread = KVCacheStoreRecvingThread(
|
||||
self.m_store, self.token_database, self.block_size,
|
||||
self.tp_rank, self.dcp_size, ready_event)
|
||||
self.kv_recv_thread.start()
|
||||
ready_event.wait()
|
||||
|
||||
def start_load_kv(self, metadata: AscendConnectorMetadata):
|
||||
self.current_layer = 0
|
||||
self.layerwise_retrievers = []
|
||||
for request in metadata.requests:
|
||||
load_spec = request.load_spec
|
||||
if load_spec is None or not load_spec.can_load:  # nothing to load for this request
|
||||
continue
|
||||
token_len = request.token_len_chunk
|
||||
if (load_spec.kvpool_cached_tokens % self.block_size
|
||||
!= 0) and (load_spec.kvpool_cached_tokens
|
||||
== token_len - 1):
|
||||
token_len = request.load_spec.kvpool_cached_tokens + 1
|
||||
else:
|
||||
token_len = request.load_spec.kvpool_cached_tokens
|
||||
request.load_spec.token_len = token_len
|
||||
if self.use_layerwise:
|
||||
layerwise_retriever = self.retrieve_layer(request)
|
||||
next(layerwise_retriever) # first layer load
|
||||
self.layerwise_retrievers.append(layerwise_retriever)
|
||||
else:
|
||||
if self.load_async:
|
||||
self.kv_recv_thread.add_request( # type: ignore[union-attr]
|
||||
request, )
|
||||
else:
|
||||
addr_list = []
|
||||
size_list = []
|
||||
key_list = []
|
||||
mask_num = (request.load_spec.vllm_cached_tokens //
|
||||
self.block_size * self.block_size)
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
token_len, request.block_hashes, mask_num):
|
||||
addr, size, _ = self.token_database.prepare_value(
|
||||
start, end, request.block_ids)
|
||||
key_list.append(key.to_string())
|
||||
addr_list.append(addr)
|
||||
size_list.append(size)
|
||||
key_list_c = key_list[self.tp_rank % len(
|
||||
key_list):] + key_list[:self.tp_rank % len(key_list)]
|
||||
addr_list_c = addr_list[self.tp_rank %
|
||||
len(addr_list
|
||||
):] + addr_list[:self.tp_rank %
|
||||
len(addr_list)]
|
||||
size_list_c = size_list[self.tp_rank %
|
||||
len(size_list
|
||||
):] + size_list[:self.tp_rank %
|
||||
len(size_list)]
|
||||
self.m_store.get(key_list_c, addr_list_c, size_list_c)
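# Illustrative effect of the rotation above (assumed values): with tp_rank=1
# and four keys [k0, k1, k2, k3], this rank issues its gets in the order
# [k1, k2, k3, k0] while rank 0 starts at k0, so concurrent ranks fan out
# over different keys of the backing store instead of all hitting k0 first.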
|
||||
|
||||
def wait_for_layer_load(self) -> None:
|
||||
for layerwise_retriever in self.layerwise_retrievers:
|
||||
ret_token_mask = next(layerwise_retriever)
|
||||
if self.current_layer == self.num_layers - 1:
|
||||
assert ret_token_mask is not None
|
||||
num_retrieved_tokens = ret_token_mask.sum().item()
|
||||
logger.debug(f"Retrieved {num_retrieved_tokens} tokens")
|
||||
|
||||
def save_kv_layer(self,
|
||||
connector_metadata: AscendConnectorMetadata) -> None:
|
||||
if self.current_layer == 0:
|
||||
self.layerwise_storers = []
|
||||
current_event = None
|
||||
for request in connector_metadata.requests:
|
||||
can_save = request.can_save
|
||||
if can_save is None or not can_save:
|
||||
continue
|
||||
current_event = torch.npu.Event()
|
||||
current_event.record()
|
||||
break
|
||||
for request in connector_metadata.requests:
|
||||
can_save = request.can_save
|
||||
if can_save is None or not can_save:
|
||||
continue
|
||||
|
||||
layerwise_storer = self.store_layer(request, current_event)
|
||||
self.layerwise_storers.append(layerwise_storer)
|
||||
for layerwise_storer in self.layerwise_storers:
|
||||
try:
|
||||
next(layerwise_storer)
|
||||
except Exception:
|
||||
raise
|
||||
self.current_layer = self.current_layer + 1
|
||||
|
||||
def wait_for_save(self, connector_metadata: AscendConnectorMetadata):
|
||||
current_event = None
|
||||
for request in connector_metadata.requests:
|
||||
can_save = request.can_save
|
||||
if can_save is None or not can_save:
|
||||
continue
|
||||
current_event = torch.npu.Event()
|
||||
current_event.record()
|
||||
break
|
||||
|
||||
for request in connector_metadata.requests:
|
||||
can_save = request.can_save
|
||||
if can_save is None or not can_save:
|
||||
continue
|
||||
|
||||
request.current_event = current_event
|
||||
self.kv_send_thread.add_stored_request( # type: ignore[union-attr]
|
||||
request.req_id)
|
||||
self.kv_send_thread.add_request( # type: ignore[union-attr]
|
||||
request, )
|
||||
|
||||
def retrieve_layer(
|
||||
self,
|
||||
request: ReqMeta,
|
||||
) -> Generator[Optional[torch.Tensor], None, None]:
|
||||
"""
|
||||
Retrieve the KV cache in a layerwise manner.
|
||||
|
||||
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
|
||||
|
||||
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
|
||||
have the same length as tokens. And the mask should ALWAYS be like
|
||||
FFFFFTTTTTTT, where True means the tokens needs to be matched.
|
||||
|
||||
:param **kwargs: The additional arguments for the KV transfer which
|
||||
will be passed into the npu_transfer.
|
||||
|
||||
return: A generator that yields Optional[torch.Tensor]. The tensor will
|
||||
be the boolean mask indicating which tokens are retrieved and will
|
||||
only be returned in the last iteration.
|
||||
"""
|
||||
token_len = request.token_len_chunk
|
||||
mask_num = (
|
||||
request.load_spec.vllm_cached_tokens # type: ignore[union-attr]
|
||||
// self.block_size * self.block_size)
|
||||
num_required_tokens = token_len - mask_num
|
||||
|
||||
ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu")
|
||||
|
||||
starts = []
|
||||
ends = []
|
||||
keys = []
|
||||
first_flag = True
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
token_len, request.block_hashes, mask_num):
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
starts.append(start)
|
||||
ends.append(end)
|
||||
keys.append(keys_multi_layer)
|
||||
ret_mask[start:end] = True
|
||||
|
||||
if keys:
|
||||
# Transpose the keys into layer major format
|
||||
keys = [list(row) for row in zip(*keys)] # [num_layer,block_num]
|
||||
for layer_id, keys_multi_chunk in enumerate(keys):
|
||||
if not first_flag:
|
||||
# wait for the previous layer's retrieval to complete
is_finish = self.get_event.wait(timeout=3)
|
||||
if not is_finish:
|
||||
logger.info("Layerwise get failed")
|
||||
self.get_event.clear()
|
||||
req_meta = LasyerMultiBlockReqMeta(request.req_id,
|
||||
keys_multi_chunk, starts,
|
||||
ends, request.block_ids,
|
||||
layer_id)
|
||||
self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
|
||||
req_meta) # type: ignore[union-attr, call-arg, arg-type]
|
||||
first_flag = False
|
||||
yield None
|
||||
else:
|
||||
# If no cache are found, we still need to yield to avoid
|
||||
# `StopIteration`
|
||||
for layer_id in range(self.num_layers):
|
||||
yield None
|
||||
|
||||
retrieved_tokens = torch.sum(ret_mask)
|
||||
logger.debug(f"Retrieved {retrieved_tokens} "
|
||||
f"out of {num_required_tokens} "
|
||||
f"out of total {token_len} tokens")
|
||||
|
||||
yield ret_mask
|
||||
|
||||
def store_layer(
|
||||
self,
|
||||
request: ReqMeta,
|
||||
current_event: Optional[torch.npu.Event],
|
||||
) -> Generator[None, None, None]:
|
||||
"""
|
||||
Store the KV cache in a layerwise manner.
|
||||
|
||||
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
|
||||
|
||||
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
|
||||
have the same length as tokens. And the mask should ALWAYS be like
|
||||
FFFFFTTTTTTT, where True means the tokens needs to be matched.
|
||||
|
||||
:param **kwargs: The additional arguments for the storage backend which
|
||||
will be passed into the gpu_connector.
|
||||
|
||||
return: A generator that yields None. In the first iteration, the
|
||||
generator allocates the memory objects for all layers and moves
|
||||
the KV cache of the first layer from GPU to CPU. In the next
|
||||
iterations, it moves the KV cache of layer i from GPU to the memory
|
||||
objects (on CPU) and puts the memory objects of layer i-1 to the
|
||||
storage backends. In the last iteration, it puts the memory objects
|
||||
of the last layer to the storage backends.
|
||||
"""
|
||||
starts = []
|
||||
ends = []
|
||||
keys = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
request.token_len_chunk, request.block_hashes):
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
starts.append(start)
|
||||
ends.append(end)
|
||||
keys.append(keys_multi_layer) #[block_num,layer_num]
|
||||
|
||||
if keys:
|
||||
keys = [list(row) for row in zip(*keys)] #[layer_num,block_num]
|
||||
for layer_id, keys_multi_chunk in enumerate(keys):
|
||||
req_meta = LasyerMultiBlockReqMeta(request.req_id,
|
||||
keys_multi_chunk, starts,
|
||||
ends, request.block_ids,
|
||||
layer_id,
|
||||
request.is_last_chunk,
|
||||
current_event)
|
||||
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
|
||||
req_meta) # type: ignore[union-attr, call-arg, arg-type]
|
||||
yield
|
||||
else:
|
||||
for layer_id in range(self.num_layers):
|
||||
yield
|
||||
|
||||
def get_finished(self,
|
||||
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
||||
done_sending = (
|
||||
self.get_and_clear_finished_requests(
|
||||
finished_req_ids # type: ignore[union-attr]
|
||||
) if self.kv_role in ['kv_producer', 'kv_both']
|
||||
or self.consumer_is_to_put else set())
|
||||
|
||||
done_recving = (
|
||||
self.kv_recv_thread.
|
||||
get_and_clear_finished_requests( # type: ignore[union-attr]
|
||||
) if self.load_async else set())
|
||||
|
||||
logger.debug(
|
||||
"Number of completed KV cache send requests: %d, receive "
|
||||
"requests: %d, tp_rank:%d", len(done_sending), len(done_recving),
|
||||
self.tp_rank)
|
||||
return done_sending, done_recving
|
||||
|
||||
def get_and_clear_finished_requests(self, finished_req_ids) -> set[str]:
|
||||
finished_sending = set()
|
||||
for req_id in self.kv_send_thread.stored_requests.copy( # type: ignore[union-attr]
|
||||
):
|
||||
if self.kv_send_thread.stored_requests[ # type: ignore[union-attr]
|
||||
req_id] == 0 and req_id in self.finished_store_req:
|
||||
self.finished_store_req.remove(req_id)
|
||||
finished_sending.add(req_id)
|
||||
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
|
||||
req_id)
|
||||
|
||||
for req_id in finished_req_ids:
|
||||
req_remain_jobs = self.kv_send_thread.stored_requests.get( # type: ignore[union-attr]
|
||||
req_id)
|
||||
if req_remain_jobs == 0:
|
||||
finished_sending.add(req_id)
|
||||
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
|
||||
req_id)
|
||||
elif req_remain_jobs is not None:
|
||||
self.finished_store_req.add(req_id)
|
||||
|
||||
return finished_sending
|
||||
|
||||
def lookup(
|
||||
self,
|
||||
token_len: int,
|
||||
block_hashes: list[BlockHash],
|
||||
use_layerwise: bool,
|
||||
) -> int:
|
||||
"""
|
||||
Checks the existence of KV cache of the tokens from the cache engine.
|
||||
:param tokens: the input tokens, with shape [seq_len]
|
||||
:return: An int indicating how many prefix tokens are cached.
|
||||
"""
|
||||
end = 0
|
||||
keys = []
|
||||
try:
|
||||
starts = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
token_len, block_hashes):
|
||||
if use_layerwise:
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
for item in keys_multi_layer:
|
||||
keys.append(item.to_string())
|
||||
else:
|
||||
keys.append(key.to_string())
|
||||
starts.append(start)
|
||||
|
||||
res = self.m_store.exists(keys) # type: ignore[assignment]
|
||||
|
||||
if use_layerwise:
|
||||
res = self.check_all_layers_exists(res, self.num_layers)
|
||||
for index, value in enumerate(res): # type: ignore[arg-type]
|
||||
if value != 1:
|
||||
return starts[index]
|
||||
# all tokens were found, return the maximal end
|
||||
except Exception as e:
|
||||
logger.error(f"Remote connection failed in contains: {e}")
|
||||
return start
|
||||
return end
|
||||
|
||||
def lookup_scheduler(
|
||||
self,
|
||||
token_len: int,
|
||||
block_hashes: list[BlockHash],
|
||||
use_layerwise: bool,
|
||||
) -> int:
|
||||
"""
|
||||
Checks the existence of KV cache of the tokens from the cache engine.
|
||||
:param tokens: the input tokens, with shape [seq_len]
|
||||
:return: An int indicating how many prefix tokens are cached.
|
||||
"""
|
||||
end = 0
|
||||
keys = []
|
||||
try:
|
||||
starts = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
token_len, block_hashes):
|
||||
if use_layerwise:
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
for item in keys_multi_layer:
|
||||
keys.append(item.to_string())
|
||||
else:
|
||||
keys.append(key.to_string())
|
||||
starts.append(start)
|
||||
|
||||
multi_tp_keys = keys[:]
|
||||
for i in range(1, min(self.tp_size, self.num_kv_head)):
|
||||
for item in keys:
|
||||
new_str = item.replace( # type: ignore[attr-defined]
|
||||
"@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1)
|
||||
multi_tp_keys.append(new_str)
|
||||
|
||||
for i in range(1, self.pp_size):
|
||||
for item in keys:
|
||||
new_str = item.replace( # type: ignore[attr-defined]
|
||||
"@pp_rank:0", f"@pp_rank:{i}", 1)
|
||||
multi_tp_keys.append(new_str)
|
||||
|
||||
res = self.m_store.exists(
|
||||
multi_tp_keys) # type: ignore[assignment]
|
||||
num_block = len(keys)
|
||||
if use_layerwise:
|
||||
res = self.check_all_layers_exists(res, self.num_layers)
|
||||
num_block = len(keys) // self.num_layers
|
||||
multi_tp_values = [
|
||||
res[i * num_block:(i + 1) * num_block] # type: ignore[index]
|
||||
for i in range(
|
||||
min(self.tp_size, self.num_kv_head) * self.pp_size)
|
||||
]
|
||||
index = self.find_min_first_non_one_index(multi_tp_values)
|
||||
if index != -1:
|
||||
return starts[index]
|
||||
# all tokens were found, return the maximal end
|
||||
except Exception as e:
|
||||
logger.error(f"Remote connection failed in contains: {e}")
|
||||
return start
|
||||
return end
|
||||
|
||||
def check_all_layers_exists(self, res: list[int],
|
||||
num_layers: int) -> list[int]:
|
||||
total_chunks = len(res) // num_layers
|
||||
result = []
|
||||
|
||||
for chunk_idx in range(total_chunks):
|
||||
start = chunk_idx * num_layers
|
||||
end = start + num_layers
|
||||
chunk = res[start:end]
|
||||
result.append(1 if all(x == 1 for x in chunk) else 0)
|
||||
|
||||
return result
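# Example: with num_layers=2 and res=[1, 1, 1, 0, 1, 1], the per-layer hits
# are grouped into chunks [1, 1], [1, 0], [1, 1], so the reduced result is
# [1, 0, 1]; a chunk only counts as cached if every layer of it exists.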
|
||||
|
||||
def find_min_first_non_one_index(self, arr):
|
||||
try:
|
||||
return min(idx for row in arr for idx, val in enumerate(row)
|
||||
if val != 1)
|
||||
except ValueError:
|
||||
return -1
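# Example: for arr=[[1, 1, 0], [1, 0, 1]] the first non-one index is 2 in the
# first row and 1 in the second, so the function returns 1, i.e. the earliest
# chunk that is missing on any shard.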
|
||||
@@ -0,0 +1,203 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
from vllm.logger import logger
|
||||
from vllm.utils.hashing import sha256
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||
PrefixCachingMetrics)
|
||||
from vllm.v1.core.single_type_kv_cache_manager import \
|
||||
get_manager_for_kv_cache_spec
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
class CPUCacheStats:
|
||||
|
||||
def __init__(self, enable_prefix_caching: bool, log_stats: bool = False):
|
||||
self.enable_prefix_caching = enable_prefix_caching
|
||||
self.log_stats = log_stats
|
||||
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
|
||||
self.cpu_prefix_cache_metrics = PrefixCachingMetrics()
|
||||
self.time_sec = int(time.time())
|
||||
|
||||
def log(self):
|
||||
current_time_sec = int(time.time())
|
||||
# Log the prefix cache hit rate every 10 seconds.
|
||||
if current_time_sec - self.time_sec >= 10:
|
||||
self.time_sec = current_time_sec
|
||||
logger.info("CPU Prefix cache hit rate: %.1f%%",
|
||||
self.cpu_prefix_cache_metrics.hit_rate * 100)
|
||||
|
||||
def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
|
||||
"""Get (and reset) the prefix cache stats.
|
||||
Returns:
|
||||
The current prefix caching stats, or None if logging is disabled.
|
||||
"""
|
||||
if not self.log_stats:
|
||||
return None
|
||||
stats = self.prefix_cache_stats
|
||||
self.prefix_cache_stats = PrefixCacheStats()
|
||||
return stats
|
||||
|
||||
def update(self, num_tokens, num_computed_tokens):
|
||||
# Note: this function is called by the scheduler
|
||||
if self.log_stats and self.enable_prefix_caching:
|
||||
assert self.prefix_cache_stats is not None
|
||||
self.prefix_cache_stats.requests += 1
|
||||
self.prefix_cache_stats.queries += num_tokens
|
||||
self.prefix_cache_stats.hits += num_computed_tokens
|
||||
|
||||
def set_cache_stats(self, num_tokens, num_computed_tokens):
|
||||
assert self.prefix_cache_stats is not None
|
||||
self.prefix_cache_stats.hits = num_computed_tokens
|
||||
self.prefix_cache_stats.queries = num_tokens
|
||||
self.prefix_cache_stats.requests = 1
|
||||
|
||||
|
||||
class CPUKVCacheManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
num_cpu_blocks: int,
|
||||
caching_hash_algo: str = "builtin",
|
||||
use_eagle: bool = False,
|
||||
enable_kv_cache_events: bool = False,
|
||||
) -> None:
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.num_cpu_blocks = num_cpu_blocks
|
||||
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
|
||||
self.use_eagle = use_eagle
|
||||
self.block_pool = BlockPool(self.num_cpu_blocks, True,
|
||||
enable_kv_cache_events)
|
||||
self.single_type_manager = get_manager_for_kv_cache_spec(
|
||||
kv_cache_spec=kv_cache_spec,
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_group_id=0,
|
||||
)
|
||||
# Record kv block hashes to avoid redundant computation.
|
||||
self.req_to_block_hashes: defaultdict[
|
||||
str, list[BlockHash]] = defaultdict(list)
|
||||
# Record blocks touched in get_matched_num_and_touch().
|
||||
self.req_to_computed_blocks: defaultdict[
|
||||
str, list[KVCacheBlock]] = defaultdict(list)
|
||||
# Record the request that failed to allocate.
|
||||
self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool)
|
||||
self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int)
|
||||
self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True,
|
||||
log_stats=True)
|
||||
# Record requests that will be freed after sending finishes
|
||||
self.req_to_free: defaultdict[str, Request] = defaultdict(Request)
|
||||
|
||||
def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]:
|
||||
# When the request requires prompt logprobs, we skip prefix caching.
|
||||
if (request.sampling_params.prompt_logprobs is not None):
|
||||
return 0, False
|
||||
request_id = request.request_id
|
||||
# The block hashes for the request may already be computed
|
||||
# if the scheduler has tried to schedule the request before.
|
||||
block_hashes = self.req_to_block_hashes[request_id]
|
||||
if not block_hashes:
|
||||
block_hashes = request.block_hashes
|
||||
self.req_to_block_hashes[request_id] = block_hashes
|
||||
max_cache_hit_length = request.num_tokens - 1
|
||||
computed_blocks = self.single_type_manager.find_longest_cache_hit(
|
||||
block_hashes=block_hashes,
|
||||
max_length=max_cache_hit_length,
|
||||
kv_cache_group_ids=[0],
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_spec=self.single_type_manager.kv_cache_spec,
|
||||
use_eagle=self.use_eagle,
|
||||
)
|
||||
num_computed_tokens = len(computed_blocks[0]) * self.block_size
|
||||
self.req_to_computed_blocks[request_id] = computed_blocks[0]
|
||||
# We should touch these blocks in the concurrent scenarios.
|
||||
self.block_pool.touch(computed_blocks)
|
||||
|
||||
# cpu prefix cache stats: update and log
|
||||
assert self.cpu_cache_stats is not None and self.cpu_cache_stats.prefix_cache_stats is not None
|
||||
self.cpu_cache_stats.set_cache_stats(request.num_tokens,
|
||||
num_computed_tokens)
|
||||
self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(
|
||||
self.cpu_cache_stats.prefix_cache_stats)
|
||||
self.cpu_cache_stats.log()
|
||||
|
||||
return num_computed_tokens, False
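# Example (illustrative numbers): with block_size=128 and three fully matched
# prefix blocks, num_computed_tokens is 3 * 128 = 384; CPU cache hits are
# always counted in whole blocks.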
|
||||
|
||||
def _release_ahead_touch(self, request_id: str):
|
||||
computed_blocks = self.req_to_computed_blocks[request_id]
|
||||
if computed_blocks:
|
||||
self.single_type_manager.block_pool.free_blocks(
|
||||
reversed(computed_blocks))
|
||||
self.req_to_computed_blocks.pop(request_id, None)
|
||||
|
||||
def allocate_slots(self, req_to_num_tokens: dict[str, int],
|
||||
unallocated_req_ids: set[str]) -> dict[str, list[int]]:
|
||||
for request_id in unallocated_req_ids:
|
||||
self._free_slots(request_id)
|
||||
req_to_new_blocks = {}
|
||||
for request_id, num_tokens in req_to_num_tokens.items():
|
||||
if self.req_failed_to_allocate[request_id]:
|
||||
continue
|
||||
new_computed_blocks = self.req_to_computed_blocks[request_id]
|
||||
num_blocks_to_allocate = (
|
||||
self.single_type_manager.get_num_blocks_to_allocate(
|
||||
request_id=request_id,
|
||||
num_tokens=num_tokens,
|
||||
new_computed_blocks=new_computed_blocks,
|
||||
))
|
||||
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
|
||||
self._release_ahead_touch(request_id)
|
||||
self.req_failed_to_allocate[request_id] = True
|
||||
continue
|
||||
# Append the new computed blocks to the request blocks until now to
|
||||
# avoid the case where the new blocks cannot be allocated.
|
||||
self.single_type_manager.save_new_computed_blocks(
|
||||
request_id, new_computed_blocks)
|
||||
# Allocate new blocks but do not cache now.
|
||||
new_blocks = self.single_type_manager.allocate_new_blocks(
|
||||
request_id, num_tokens)
|
||||
self.req_to_num_tokens[request_id] = num_tokens
|
||||
# No need to release the ref_cnt here: these blocks are now officially in use by the request.
|
||||
self.req_to_computed_blocks.pop(request_id, None)
|
||||
req_to_new_blocks[request_id] = [
|
||||
block.block_id for block in new_computed_blocks + new_blocks
|
||||
]
|
||||
return req_to_new_blocks
|
||||
|
||||
def record_request_cache_and_free_slots(self, request: Request):
|
||||
logger.debug(
|
||||
f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager"
|
||||
)
|
||||
self.req_to_free[request.request_id] = request
|
||||
|
||||
def cache_and_free_slots(self, request_id: str):
|
||||
logger.debug(
|
||||
f"Cache and free slots for request {request_id} in cpu_kv_cache_manager"
|
||||
)
|
||||
if request_id not in self.req_to_free:
|
||||
logger.error(
|
||||
f"request {request_id} not in req_to_free, maybe bug!")
|
||||
return
|
||||
request = self.req_to_free[request_id]
|
||||
if not self.req_failed_to_allocate[request_id]:
|
||||
self.single_type_manager.cache_blocks(
|
||||
request,
|
||||
self.req_to_num_tokens[request_id],
|
||||
)
|
||||
self._free_slots(request_id)
|
||||
logger.debug(
|
||||
f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
|
||||
del self.req_to_free[request_id]
|
||||
|
||||
def _free_slots(self, request_id: str):
|
||||
# This function is designed to be reentrant.
|
||||
self._release_ahead_touch(request_id)
|
||||
self.single_type_manager.free(request_id)
|
||||
self.req_to_block_hashes.pop(request_id, None)
|
||||
self.req_to_computed_blocks.pop(request_id, None)
|
||||
self.req_failed_to_allocate.pop(request_id, None)
|
||||
self.req_to_num_tokens.pop(request_id, None)
|
||||
@@ -0,0 +1,528 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional, Sequence
|
||||
|
||||
import torch
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.layer import Attention, MLAAttention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
|
||||
MambaSpec, MLAAttentionSpec)
|
||||
|
||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.metadata import (
|
||||
MetadataServer, MetadataServerProc, MLAConfig)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
gpu_block_ids: list[int]
|
||||
cpu_block_ids: list[int]
|
||||
num_scheduled_tokens: int
|
||||
num_computed_tokens: int
|
||||
num_gpu_computed_tokens: int
|
||||
num_cpu_computed_tokens: int
|
||||
|
||||
def update(self, other: "ReqMeta"):
|
||||
self.gpu_block_ids.extend(other.gpu_block_ids)
|
||||
self.cpu_block_ids.extend(other.cpu_block_ids)
|
||||
self.num_scheduled_tokens = other.num_scheduled_tokens
|
||||
self.num_computed_tokens = other.num_computed_tokens
|
||||
self.num_gpu_computed_tokens = other.num_gpu_computed_tokens
|
||||
self.num_cpu_computed_tokens = other.num_cpu_computed_tokens
|
||||
|
||||
|
||||
@dataclass
|
||||
class CPUOffloadingConnectorMetadata(KVConnectorMetadata):
|
||||
requests: dict[str, ReqMeta]
|
||||
finished_req_ids: set[str]
|
||||
|
||||
|
||||
class CPUOffloadingConnector(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional[KVCacheConfig] = None):
|
||||
if not vllm_config.cache_config.enable_prefix_caching:
|
||||
self.connector_scheduler: Optional[
|
||||
CPUOffloadingConnectorScheduler] = None
|
||||
self.connector_worker: Optional[
|
||||
CPUOffloadingConnectorWorker] = None
|
||||
elif role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler = CPUOffloadingConnectorScheduler(
|
||||
vllm_config)
|
||||
self.connector_worker = None
|
||||
elif role == KVConnectorRole.WORKER:
|
||||
self.connector_scheduler = None
|
||||
self.connector_worker = CPUOffloadingConnectorWorker(vllm_config)
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
|
||||
def bind_connector_metadata(
|
||||
self, connector_metadata: KVConnectorMetadata) -> None:
|
||||
if self.connector_worker is not None:
|
||||
assert isinstance(connector_metadata,
|
||||
CPUOffloadingConnectorMetadata)
|
||||
self.connector_worker.bind_connector_metadata(connector_metadata)
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.clear_connector_metadata()
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
if self.connector_worker is not None:
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
if self.connector_worker is not None:
|
||||
self.connector_worker.start_load_kv()
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
if self.connector_worker is not None:
|
||||
self.connector_worker.wait_for_layer_load()
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||
pass
|
||||
|
||||
def wait_for_save(self):
|
||||
pass
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||
assert self.connector_worker is not None
|
||||
return self.connector_worker.get_finished(), None
|
||||
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
if self.connector_scheduler is not None:
|
||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens)
|
||||
return 0, False
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
if self.connector_scheduler is not None:
|
||||
return self.connector_scheduler.update_state_after_alloc(request)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
if self.connector_scheduler is not None:
|
||||
return self.connector_scheduler.build_connector_meta(
|
||||
scheduler_output)
|
||||
return KVConnectorMetadata()
|
||||
|
||||
def request_finished(
|
||||
self, request: "Request",
|
||||
block_ids: list[int]) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
if self.connector_scheduler is not None:
|
||||
self.connector_scheduler.request_finished(request)
|
||||
return True, None
|
||||
|
||||
|
||||
class CPUOffloadingConnectorScheduler:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
logger.info("init CPUOffloadingConnectorScheduler")
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.use_mla = vllm_config.model_config.use_mla
|
||||
self.num_gpu_computed_tokens: dict[str, int] = {}
|
||||
self.num_cpu_computed_tokens: dict[str, int] = {}
|
||||
self.allocated_req_ids: set[str] = set()
|
||||
self.finished_req_ids: list[str] = []
|
||||
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
|
||||
self.zmq_rpc_client.call("post_init")
|
||||
if vllm_config.kv_transfer_config is not None:
|
||||
self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"swap_in_threshold", 0)
|
||||
else:
|
||||
self.swap_in_threshold = 0
|
||||
logger.info(f"swap_in_threshold: {self.swap_in_threshold}")
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, ori_request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
request = copy.deepcopy(ori_request)
|
||||
# drop the block-hash closure so the copied request stays picklable for the RPC call
request.get_hash_new_full_blocks = None
|
||||
num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call(
|
||||
"get_matched_num_and_touch", request)
|
||||
self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens
|
||||
self.num_cpu_computed_tokens[
|
||||
request.request_id] = num_cpu_computed_tokens
|
||||
if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold:
|
||||
return num_cpu_computed_tokens - num_computed_tokens, load_async
|
||||
else:
|
||||
return 0, load_async
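# Example (assumed numbers): with swap_in_threshold=256, a request with 512
# CPU-cached tokens and 128 GPU-cached tokens reports 384 tokens to swap in
# (384 >= 256), while one with only 256 CPU-cached tokens reports 0
# (128 < 256) and simply recomputes on the NPU.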
|
||||
|
||||
def update_state_after_alloc(self, request: "Request"):
|
||||
self.allocated_req_ids.add(request.request_id)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
num_tokens = {}
|
||||
# process scheduled_new_reqs
|
||||
for req in scheduler_output.scheduled_new_reqs:
|
||||
req_id = req.req_id
|
||||
num_tokens[req_id] = (
|
||||
req.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
|
||||
# process scheduled_cached_reqs
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
for idx, req_id in enumerate(cached_reqs.req_ids):
|
||||
num_tokens[req_id] = (
|
||||
cached_reqs.num_computed_tokens[idx] +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
|
||||
unallocated_req_ids = set(self.num_gpu_computed_tokens.keys() -
|
||||
self.allocated_req_ids -
|
||||
scheduler_output.num_scheduled_tokens.keys())
|
||||
new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots",
|
||||
num_tokens,
|
||||
unallocated_req_ids)
|
||||
metadata = CPUOffloadingConnectorMetadata(
|
||||
requests={},
|
||||
finished_req_ids=set(self.finished_req_ids),
|
||||
)
|
||||
for req in scheduler_output.scheduled_new_reqs:
|
||||
req_id = req.req_id
|
||||
gpu_block_ids = req.block_ids[0]
|
||||
metadata.requests[req_id] = ReqMeta(
|
||||
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
|
||||
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
|
||||
num_scheduled_tokens=scheduler_output.
|
||||
num_scheduled_tokens[req_id],
|
||||
num_computed_tokens=req.num_computed_tokens,
|
||||
num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id],
|
||||
num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id])
|
||||
|
||||
for idx, req_id in enumerate(cached_reqs.req_ids):
|
||||
gpu_block_ids = cached_reqs.new_block_ids[idx]
|
||||
metadata.requests[req_id] = ReqMeta(
|
||||
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
|
||||
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
|
||||
num_scheduled_tokens=scheduler_output.
|
||||
num_scheduled_tokens[req_id],
|
||||
num_computed_tokens=cached_reqs.num_computed_tokens[idx],
|
||||
num_gpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
|
||||
num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx])
|
||||
self.num_gpu_computed_tokens.clear()
|
||||
self.num_cpu_computed_tokens.clear()
|
||||
self.allocated_req_ids.clear()
|
||||
self.finished_req_ids.clear()
|
||||
return metadata
|
||||
|
||||
def request_finished(self, ori_request: "Request"):
|
||||
request = copy.deepcopy(ori_request)
|
||||
request.get_hash_new_full_blocks = None
|
||||
self.finished_req_ids.append(request.request_id)
|
||||
# inform the metadata server to record the request and free it once sending finishes
|
||||
self.zmq_rpc_client.call("record_request_cache_and_free_slots",
|
||||
request)


class CPUOffloadingConnectorWorker:

    def __init__(self, vllm_config: VllmConfig):
        logger.info("init CPUOffloadingConnectorWorker")
        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size
        self.pp_rank = get_pp_group().rank_in_group
        self.tp_group = get_tp_group()
        self.tp_rank = self.tp_group.rank_in_group
        self.tp_world_size = self.tp_group.world_size
        self.use_mla = vllm_config.model_config.use_mla

        self.requests: dict[str, ReqMeta] = {}
        self.load_stream = torch.npu.Stream()
        self.save_stream = torch.npu.Stream()
        self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
        self.load_block_mapping: list[tuple[int, int]] = []
        self.save_input_queue: queue.Queue[tuple[str, ReqMeta]] = queue.Queue()
        self.save_output_queue: queue.Queue[str] = queue.Queue()
        self.save_thread = threading.Thread(target=self._save_listener)
        self.save_thread.start()
        self.done_sending_count: defaultdict[str, int] = defaultdict(int)

        # Start the metadata server to init cpu_kv_cache_manager and handle
        # RPC requests. All DP ranks share the same metadata server, so the
        # process is only started on DP rank 0 (TP rank 0, PP rank 0).
        if vllm_config.parallel_config.data_parallel_rank == 0 and self.tp_rank == 0 and self.pp_rank == 0:
            config = VllmConfig()
            config.cache_config = vllm_config.cache_config
            config.parallel_config = vllm_config.parallel_config
            config.kv_transfer_config = vllm_config.kv_transfer_config
            self.init_metadata_server(config)
        self._wait_for_metadata_process_start()

    def init_metadata_server(self, vllm_config: VllmConfig):
        self.metadata_thread = threading.Thread(
            target=MetadataServerProc.run_metadata_server,
            args=(vllm_config, ),
        )
        self.metadata_thread.daemon = True
        self.metadata_thread.start()

    def _wait_for_metadata_process_start(self):
        # TODO: wait for the metadata server to start; add an RPC to check
        # whether it is ready.
        while True:
            try:
                if self.zmq_rpc_client.call("ready"):
                    break
            except Exception as e:
                logger.info(f"wait for metadata server to start, error: {e}")
            time.sleep(1)

    def bind_connector_metadata(
            self, connector_metadata: CPUOffloadingConnectorMetadata) -> None:
        for req_id, req in connector_metadata.requests.items():
            if req_id in self.requests:
                self.requests[req_id].update(req)
                req = self.requests[req_id]
            else:
                self.requests[req_id] = req
            for i in range(req.num_gpu_computed_tokens // self.block_size,
                           req.num_computed_tokens // self.block_size):
                self.load_block_mapping.append(
                    (req.cpu_block_ids[i], req.gpu_block_ids[i]))
        for req_id in connector_metadata.finished_req_ids:
            if req_id in self.requests:
                self.save_input_queue.put((req_id, self.requests[req_id]))

    def clear_connector_metadata(self) -> None:
        self.load_block_mapping.clear()

    def register_kv_caches(self, kv_caches: dict[str, Sequence[torch.Tensor]]):
        self.gpu_kv_caches = kv_caches
        model_config = self.vllm_config.model_config
        mla_config: Optional[MLAConfig] = None
        if model_config.use_mla:
            mla_config = MLAConfig(
                model_config.hf_text_config.kv_lora_rank,
                model_config.hf_text_config.qk_rope_head_dim)
        self.cpu_kv_caches = list(
            self.zmq_rpc_client.call(
                "init_cpu_kv_caches",
                self.pp_rank,
                self.tp_rank,
                get_kv_cache_spec(self.vllm_config),
                mla_config,
            ).values())

    def start_load_kv(self) -> None:
        self.current_layer = 0
        self.gpu_kv_caches_load_iter = iter(self.gpu_kv_caches.values())
        self.load_kv_layer(0)

    def wait_for_layer_load(self) -> None:
        # TODO: Replace with `torch.npu.current_stream().wait_stream(self.load_stream)` after fixing the bug.
        self.load_stream.synchronize()
        self.current_layer += 1
        self.load_kv_layer(self.current_layer)

    def load_kv_layer(self, layer: int):
        if layer == len(self.gpu_kv_caches):
            return
        gpu_kv_caches = next(self.gpu_kv_caches_load_iter)
        cpu_kv_caches = self.cpu_kv_caches[layer]
        with torch.npu.stream(self.load_stream):
            for cpu_block_id, gpu_block_id in self.load_block_mapping:
                for gpu_layer_part, cpu_layer_part in zip(
                        gpu_kv_caches, cpu_kv_caches):
                    gpu_layer_part[gpu_block_id].copy_(
                        cpu_layer_part[cpu_block_id], non_blocking=True)

    def get_finished(self) -> set[str]:
        done_sending: set[str] = set()
        while True:
            try:
                id = self.save_output_queue.get_nowait()
            except queue.Empty:
                break
            done_sending.add(id)
        for id in done_sending:
            del self.requests[id]
        if self.tp_world_size == 1:
            return done_sending
        if self.tp_rank == 0:
            for req_id in done_sending:
                self.done_sending_count[req_id] += 1
            other_ranks_finished_ids: list[str] = []
            for i in range(1, self.tp_world_size):
                other_ranks_finished_ids.extend(
                    self.tp_group.recv_object(src=i))
            for req_id in other_ranks_finished_ids:
                self.done_sending_count[req_id] += 1
            all_done_sending: set[str] = set()
            for req_id in list(self.done_sending_count.keys()):
                if self.done_sending_count[req_id] == self.tp_world_size:
                    del self.done_sending_count[req_id]
                    all_done_sending.add(req_id)
            # Release the CPU KV cache only after every rank has finished
            # sending the request. To avoid blocking on the RPC, call it
            # asynchronously from a thread.
            sending_finished_thread = threading.Thread(
                target=self._sending_finished, args=(all_done_sending, ))
            sending_finished_thread.daemon = True
            sending_finished_thread.start()

            return all_done_sending
        else:
            self.tp_group.send_object(done_sending, dst=0)
            return done_sending

    def _sending_finished(self, all_done_sending):
        for req_id in all_done_sending:
            logger.debug(f"call cache_and_free_slots for req_id: {req_id}")
            self.zmq_rpc_client.call("cache_and_free_slots", req_id)

    def _save_listener(self):
        save_block_mapping = []
        while True:
            req_id, req = self.save_input_queue.get()
            for i in range(
                    req.num_cpu_computed_tokens // self.block_size,
                    min((req.num_computed_tokens + req.num_scheduled_tokens) //
                        self.block_size, len(req.cpu_block_ids))):
                save_block_mapping.append(
                    (req.gpu_block_ids[i], req.cpu_block_ids[i]))
            with torch.npu.stream(self.save_stream):
                # MLA: kv_layer is tuple[tensor, tensor], meaning (rope, nope).
                # non-MLA: kv_layer is list[tensor], typically [k, v].
                if self.use_mla:
                    start, step = self.tp_rank, self.tp_world_size
                else:
                    start, step = 0, 1
                for i in range(start, len(save_block_mapping), step):
                    gpu_block_id, cpu_block_id = save_block_mapping[i]
                    for cpu_kv_caches, gpu_kv_caches in zip(
                            self.cpu_kv_caches, self.gpu_kv_caches.values()):
                        for cpu_layer_part, gpu_layer_part in zip(
                                cpu_kv_caches, gpu_kv_caches):
                            cpu_layer_part[cpu_block_id].copy_(
                                gpu_layer_part[gpu_block_id],
                                non_blocking=True)
            self.save_stream.synchronize()
            self.save_output_queue.put(req_id)
            save_block_mapping.clear()

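
# A minimal sketch of the striping used in _save_listener above, assuming the
# MLA KV cache is replicated across TP ranks (so the blocks to offload can be
# divided round-robin between ranks); the helper is illustrative only and is
# not used by the connector itself.
def _striped_save_indices(num_blocks: int, tp_rank: int, tp_world_size: int,
                          use_mla: bool) -> list[int]:
    # Mirrors the loop bounds above: with MLA, rank r copies blocks
    # r, r + tp_world_size, r + 2 * tp_world_size, ...; without MLA every rank
    # copies all blocks of its own KV shard.
    start, step = (tp_rank, tp_world_size) if use_mla else (0, 1)
    return list(range(start, num_blocks, step))
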

# Copied and modified from vllm_ascend/worker/model_runner_v1.py.
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
    """
    Generates the KVCacheSpec by parsing the KV cache format from each
    Attention module in the static forward context.

    Returns:
        KVCacheSpec: A dictionary mapping layer names to their KV cache
        format. Layers that do not need KV cache are not included.
    """
    if has_ec_transfer() and get_ec_transfer().is_producer:
        return {}

    block_size = vllm_config.cache_config.block_size
    use_mla = vllm_config.model_config.use_mla
    use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
    if vllm_config.cache_config.cache_dtype == "auto":
        kv_cache_dtype = vllm_config.model_config.dtype
    else:
        kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
            vllm_config.cache_config.cache_dtype]
    kv_cache_spec: dict[str, KVCacheSpec] = {}
    attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
    for layer_name, attn_module in attn_layers.items():
        if isinstance(attn_module, Attention):
            # TODO: Support other attention modules, e.g., cross-attention.
            # TODO(lucas): move the attention specs into the model layers like
            # the attention backends.
            if attn_module.attn_type == AttentionType.DECODER:
                kv_cache_spec[layer_name] = FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=attn_module.num_kv_heads,
                    head_size=attn_module.head_size,
                    dtype=kv_cache_dtype)
            elif attn_module.attn_type in (AttentionType.ENCODER,
                                           AttentionType.ENCODER_ONLY):
                # Encoder-only attention does not need KV cache.
                continue
            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            else:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

        elif isinstance(attn_module, MLAAttention):
            if use_mla and not use_sparse:
                kv_cache_spec[layer_name] = MLAAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=attn_module.head_size,
                    dtype=kv_cache_dtype,
                    cache_dtype_str=vllm_config.cache_config.cache_dtype)
            else:
                # TODO(cmq): This is a hacky way to fix the DeepSeek KV cache
                # when using DSA. Fixing the spec in vLLM is the final way.
                kv_cache_spec[layer_name] = FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=attn_module.head_size,
                    dtype=kv_cache_dtype)

    mamba_layers = get_layers_from_vllm_config(vllm_config, MambaBase)
    if len(mamba_layers) > 0:
        if (vllm_config.speculative_config is not None
                and vllm_config.model_config.hf_config.model_type
                not in ["qwen3_next"]):
            raise NotImplementedError(
                "Mamba with speculative decoding is not supported yet.")
        if vllm_config.cache_config.enable_prefix_caching:
            raise NotImplementedError(
                "Prefix caching is not supported for Mamba yet.")
        max_model_len = vllm_config.model_config.max_model_len

        page_size_padded = vllm_config.cache_config.mamba_page_size_padded

        # Set block_size to max_model_len, so that a mamba model will always
        # have only one block in the KV cache.
        for layer_name, mamba_module in mamba_layers.items():
            kv_cache_spec[layer_name] = MambaSpec(
                shapes=mamba_module.get_state_shape(),
                dtypes=mamba_module.get_state_dtype(),
                block_size=max_model_len,
                page_size_padded=page_size_padded,
                mamba_type=mamba_module.mamba_type,
                num_speculative_blocks=(
                    vllm_config.speculative_config.num_speculative_tokens
                    if vllm_config.speculative_config else 0),
            )

    return kv_cache_spec
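
# A rough worked example of what one FullAttentionSpec entry above implies for
# the size of a CPU-side block, assuming the usual K+V layout of
# 2 * block_size * num_kv_heads * head_size * dtype_size bytes per block and
# per layer (the concrete numbers are illustrative only):
#
#     block_size=128, num_kv_heads=8, head_size=128, bf16 (2 bytes)
#     2 * 128 * 8 * 128 * 2 = 524288 bytes = 512 KiB per block per layer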
246
vllm_ascend/distributed/kv_transfer/ucm_connector.py
Normal file
246
vllm_ascend/distributed/kv_transfer/ucm_connector.py
Normal file
@@ -0,0 +1,246 @@
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Any, Optional

import torch
from ucm.integration.vllm.ucm_connector import UCMConnector
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput

logger = init_logger(__name__)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
        KVConnectorPromMetrics, KVConnectorStats, PromMetric, PromMetricT)
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.kv_cache_interface import KVCacheConfig
    from vllm.v1.request import Request


class UCMConnectorV1(KVConnectorBase_V1):

    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: "KVCacheConfig",
    ):
        super().__init__(vllm_config=vllm_config,
                         role=role,
                         kv_cache_config=kv_cache_config)
        assert vllm_config.kv_transfer_config is not None

        ImplCls = UCMConnector
        self._ucm_engine = ImplCls(vllm_config, role)

    # ==============================
    # Worker-side methods
    # ==============================
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]) -> None:
        """
        Initialize with the KV caches. Useful for pre-registering the
        KV caches in the KVConnector (e.g. for NIXL).

        Args:
            kv_caches: A dictionary mapping layer names to KV cache tensors.
        """
        self._ucm_engine.register_kv_caches(kv_caches)

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs: Any) -> None:
        """
        Start loading the KV cache from the connector to vLLM's paged
        KV buffer. This is called from the forward context before the
        forward pass to enable async loading during model execution.

        Args:
            forward_context (ForwardContext): the forward context.
            **kwargs: additional arguments for the load operation.

        Note:
            The number of elements in kv_caches and layer_names should be
            the same.
        """
        self._ucm_engine.start_load_kv(forward_context, **kwargs)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """
        Block until the KV for a specific layer is loaded into vLLM's
        paged buffer. This is called from within the attention layer to
        ensure the async copying from start_load_kv is complete.

        This interface will be useful for layer-by-layer pipelining.

        Args:
            layer_name: the name of that layer.
        """
        self._ucm_engine.wait_for_layer_load(layer_name)

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: "AttentionMetadata",
        **kwargs: Any,
    ) -> None:
        """
        Start saving a layer of KV cache from vLLM's paged buffer
        to the connector. This is called from within the attention layer
        to enable async copying during execution.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """
        self._ucm_engine.save_kv_layer(layer_name, kv_layer, attn_metadata,
                                       **kwargs)

    def wait_for_save(self) -> None:
        """
        Block until all the save operations are done. This is called
        as the forward context exits to ensure that the async saving
        from save_kv_layer is complete before finishing the forward.

        This prevents overwrites of the paged KV buffer before saving is done.
        """
        self._ucm_engine.wait_for_save()

    def clear_connector_metadata(self) -> None:
        """Clear the connector metadata.

        This function should be called by the model runner every time
        after the model execution.
        """
        self._ucm_engine.clear_connector_metadata()

    def bind_connector_metadata(
            self, connector_metadata: KVConnectorMetadata) -> None:
        """Set the connector metadata from the scheduler.

        This function should be called by the model runner every time
        before the model execution. The metadata will be used for runtime
        KV cache loading and saving.

        Args:
            connector_metadata (dict): the connector metadata.
        """
        self._ucm_engine.bind_connector_metadata(connector_metadata)

    def get_block_ids_with_load_errors(self) -> set[int]:
        """
        Get the set of block IDs that failed to load.

        Returns:
            Set of block IDs that encountered load errors.
            Empty set if no load errors occurred.
        """
        return self._ucm_engine.get_block_ids_with_load_errors()

    # ==============================
    # Scheduler-side methods
    # ==============================
    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """
        Get the number of new tokens that can be loaded from the
        external KV cache beyond num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request.

        Returns:
            The number of tokens that can be loaded from the
            external KV cache beyond what is already computed.
        """
        return self._ucm_engine.get_num_new_matched_tokens(
            request, num_computed_tokens)

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int) -> None:
        """
        Update KVConnector state after block allocation.
        """
        self._ucm_engine.update_state_after_alloc(request, blocks,
                                                  num_external_tokens)

    def build_connector_meta(
            self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
        """
        Build the connector metadata for this step.

        This function should NOT modify fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        return self._ucm_engine.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called when a request has finished, before its blocks are freed.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        return self._ucm_engine.request_finished(request, block_ids)

    # ==============================
    # Metrics & Stats
    # ==============================

    @classmethod
    def build_kv_connector_stats(
            cls,
            data: dict[str, Any] | None = None
    ) -> Optional["KVConnectorStats"]:
        """
        KVConnectorStats resolution method. This method allows dynamically
        registered connectors to return their own KVConnectorStats object,
        which can implement custom aggregation logic on the data dict.
        """
        return UCMConnector.build_kv_connector_stats(data)

    @classmethod
    def build_prom_metrics(
        cls,
        vllm_config: "VllmConfig",
        metric_types: dict[type["PromMetric"], type["PromMetricT"]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ) -> Optional["KVConnectorPromMetrics"]:
        """
        Create a KVConnectorPromMetrics subclass which should register
        per-connector Prometheus metrics and implement observe() to
        expose connector transfer stats via Prometheus.

        This implementation forwards the call to the underlying
        UCMConnector engine.
        """
        return UCMConnector.build_prom_metrics(
            vllm_config,
            metric_types,
            labelnames,
            per_engine_labelvalues,
        )
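
# A minimal sketch of enabling the wrapper above from user code, assuming the
# standard vLLM KVTransferConfig fields (kv_connector, kv_role); the model
# name and role here are placeholders, not a tested configuration.
#
#     from vllm import LLM
#     from vllm.config import KVTransferConfig
#
#     llm = LLM(model="Qwen/Qwen2.5-7B-Instruct",
#               kv_transfer_config=KVTransferConfig(kv_connector="UCMConnector",
#                                                   kv_role="kv_both"))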
@@ -0,0 +1,53 @@
import ipaddress
import threading
from typing import Optional

from mooncake.engine import TransferEngine  # type: ignore


class GlobalTE:
    # Process-wide singleton wrapper around the mooncake TransferEngine:
    # the engine is created lazily on first use, and memory buffers are
    # registered at most once.

    def __init__(self):
        self.transfer_engine = None
        self.is_register_buffer: bool = False
        self.transfer_engine_lock = threading.Lock()
        self.register_buffer_lock = threading.Lock()

    def get_transfer_engine(self, hostname: str, device_name: Optional[str]):
        try:
            ip = ipaddress.ip_address(hostname)
            if isinstance(ip, ipaddress.IPv6Address):
                raise RuntimeError(
                    "The backend of mooncake's Ascend Direct Xfer Library currently does not support IPv6."
                )
        except ValueError:
            pass
        if self.transfer_engine is None:
            with self.transfer_engine_lock:
                # Double-checked locking.
                if self.transfer_engine is None:
                    if TransferEngine is None:
                        raise RuntimeError("mooncake is not available")
                    self.transfer_engine = TransferEngine()
                    device_name = device_name if device_name is not None else ""
                    ret_value = self.transfer_engine.initialize(
                        hostname, "P2PHANDSHAKE", "ascend", device_name)
                    if ret_value != 0:
                        raise RuntimeError(
                            f"TransferEngine initialization failed with ret_value: {ret_value}"
                        )
        return self.transfer_engine

    def register_buffer(self, ptrs: list[int], sizes: list[int]):
        with self.register_buffer_lock:
            assert self.transfer_engine is not None, "Transfer engine must be initialized"
            if self.is_register_buffer:
                return
            for ptr, size in zip(ptrs, sizes):
                ret_value = self.transfer_engine.register_memory(ptr, size)
                if ret_value != 0:
                    raise RuntimeError("Mooncake memory registration failed.")
            self.is_register_buffer = True


global_te = GlobalTE()
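
# A minimal usage sketch for the singleton above (the IP address, device name
# and buffer are placeholders): the first caller creates and initializes the
# engine, later callers receive the same instance, and register_buffer is a
# no-op after the first successful registration.
#
#     engine = global_te.get_transfer_engine("10.0.0.1", device_name=None)
#     global_te.register_buffer(
#         [kv_buffer.data_ptr()],
#         [kv_buffer.numel() * kv_buffer.element_size()])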
61
vllm_ascend/distributed/kv_transfer/utils/utils.py
Normal file
61
vllm_ascend/distributed/kv_transfer/utils/utils.py
Normal file
@@ -0,0 +1,61 @@
import os

import torch
import torch.distributed as dist

from vllm_ascend.distributed.parallel_state import get_p_tp_group


def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,
                              value: torch.Tensor):
    if pd_tp_ratio <= 1:
        return None, None
    elif key is None or value is None:
        raise ValueError("key or value is None")
    k_output = alltoall_and_rearrange(pd_tp_ratio, key)
    v_output = alltoall_and_rearrange(pd_tp_ratio, value)
    return k_output, v_output


def alltoall_and_rearrange(tp_ratio: int, input_tensor: torch.Tensor):
    num_kv_heads = input_tensor.size(1)
    output_tensor = torch.zeros_like(input_tensor)
    dist.all_to_all_single(output_tensor,
                           input_tensor,
                           group=get_p_tp_group().device_group)
    # Drop the local references early so the buffers can be released.
    input_tensor = 0
    result = rearrange_output(output_tensor, tp_ratio, num_kv_heads)
    output_tensor = 0
    return result


def rearrange_output(base_output: torch.Tensor, cut_num: int,
                     num_kv_heads: int):
    size_0 = base_output.size(0)
    if size_0 % cut_num != 0:
        raise ValueError(
            f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]"
        )
    chunk_size = size_0 // cut_num
    reshaped = base_output.view(cut_num, chunk_size, -1)
    transposed = reshaped.transpose(0, 1)
    return transposed.contiguous().view(size_0, num_kv_heads, -1)

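
# A small worked example of rearrange_output above (values are illustrative):
# with cut_num=2 and four rows coming out of all_to_all as two contiguous
# chunks [a0, a1, b0, b1], the rows are interleaved back into per-token order:
#
#     view(2, 2, -1)   -> [[a0, a1], [b0, b1]]
#     transpose(0, 1)  -> [[a0, b0], [a1, b1]]
#     final view       -> rows [a0, b0, a1, b1], shape (4, num_kv_heads, -1)
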

def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
    # Return a view of `tensor` that starts at the first address in its
    # storage aligned to `alignment` bytes.
    data_ptr = tensor.data_ptr()
    aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
    offset = (aligned_addr - data_ptr) // tensor.element_size()
    return tensor[int(offset):]

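
# A quick numeric sketch of align_memory above (addresses are made up): for a
# float16 tensor whose data_ptr() is 0x1008 and alignment=512, the first
# aligned address is 0x1200, so (0x1200 - 0x1008) // 2 = 252 leading elements
# are skipped and the returned view starts at element 252.
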

def get_transfer_timeout_value():
    # Transfer timeout in milliseconds: either taken directly from
    # ASCEND_TRANSFER_TIMEOUT, or derived from the HCCL RDMA retry settings
    # (4.096 us * 2**HCCL_RDMA_TIMEOUT per retry, times the retry count,
    # converted to ms) plus a fixed 3000 ms margin.
    ascend_transfer_timeout = os.getenv("ASCEND_TRANSFER_TIMEOUT", "")
    if len(ascend_transfer_timeout) > 0:
        return int(ascend_transfer_timeout)
    hccl_rdma_timeout = int(os.getenv('HCCL_RDMA_TIMEOUT', '20'))
    hccl_rdma_retry_cnt = int(os.getenv('HCCL_RDMA_RETRY_CNT', '7'))
    return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 +
               3000)
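
# Worked example with the defaults above (HCCL_RDMA_TIMEOUT=20,
# HCCL_RDMA_RETRY_CNT=7), assuming the 4.096 us * 2**timeout interpretation:
#
#     4.096 * 2**20 ~= 4.29e6 us per retry
#     * 7 retries   ~= 3.006e7 us, // 1000 -> 30064 ms
#     + 3000 ms margin -> get_transfer_timeout_value() returns 33064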