[Feature][main]reconstruction kvpool connector to ascend connector (#4438)
### What this PR does / why we need it? 1. In short, we renamed the existing MooncakeStoreConnector to AscendStoreConnector and extracted the storage-engine interaction logic into a new Backend class. Associated RFC: https://github.com/vllm-project/vllm-ascend/issues/4329 2. Fixed the issue where the number of input parameters for the connector was incorrect, introduced in vLLM 0.11.2. ### Does this PR introduce _any_ user-facing change? Renames MooncakeStoreConnector to AscendStoreConnector. ### How was this patch tested? - vLLM version: v0.11.2 --------- Signed-off-by: fems14 <1804143737@qq.com>
This commit is contained in:
@@ -31,8 +31,13 @@ def register_connector():
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"MooncakeConnectorStoreV1",
|
||||
"vllm_ascend.distributed.mooncake.mooncake_store_connector_v1",
|
||||
"MooncakeConnectorV1")
|
||||
"vllm_ascend.distributed.kvpool.ascend_store_connector",
|
||||
"AscendStoreConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"AscendStoreConnector",
|
||||
"vllm_ascend.distributed.kvpool.ascend_store_connector",
|
||||
"AscendStoreConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"MooncakeLayerwiseConnector",
|
||||
|
||||
@@ -29,6 +29,7 @@ if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
@@ -58,7 +59,10 @@ class CPUOffloadingConnectorMetadata(KVConnectorMetadata):
|
||||
|
||||
class CPUOffloadingConnector(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional[KVCacheConfig] = None):
|
||||
if not vllm_config.cache_config.enable_prefix_caching:
|
||||
self.connector_scheduler: Optional[
|
||||
CPUOffloadingConnectorScheduler] = None
|
||||
|
||||
1
vllm_ascend/distributed/kvpool/__init__.py
Normal file
1
vllm_ascend/distributed/kvpool/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
194
vllm_ascend/distributed/kvpool/ascend_store_connector.py
Normal file
194
vllm_ascend/distributed/kvpool/ascend_store_connector.py
Normal file
@@ -0,0 +1,194 @@
|
||||
import threading
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.utils import logger
|
||||
from vllm.utils.network_utils import make_zmq_socket
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.serial_utils import MsgpackDecoder
|
||||
|
||||
from vllm_ascend.distributed.kvpool.pool_scheduler import (
|
||||
KVPoolScheduler, get_zmq_rpc_path_lookup)
|
||||
from vllm_ascend.distributed.kvpool.pool_worker import KVPoolWorker
|
||||
|
||||
|
||||
class AscendStoreConnector(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional[KVCacheConfig] = None):
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
role=role,
|
||||
kv_cache_config=kv_cache_config)
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
|
||||
self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"use_layerwise", False)
|
||||
|
||||
connector_name = vllm_config.kv_transfer_config.kv_connector
|
||||
if connector_name == "MooncakeConnectorStoreV1":
|
||||
logger.warning(
|
||||
"It is recommended to use the AscendStoreConnector, as the MoonCakeStoreConnector will be removed in the future."
|
||||
)
|
||||
|
||||
self.kv_caches: dict[str, torch.Tensor] = {}
|
||||
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
|
||||
self.sended_but_unfinished_reqs: set[str] = set()
|
||||
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler = KVPoolScheduler(vllm_config,
|
||||
self.use_layerwise)
|
||||
else:
|
||||
self.connector_worker = KVPoolWorker(
|
||||
vllm_config,
|
||||
self.use_layerwise,
|
||||
)
|
||||
|
||||
assert self.connector_worker is not None
|
||||
if vllm_config.parallel_config.rank == 0:
|
||||
self.lookup_server = LookupKeyServer(self.connector_worker,
|
||||
vllm_config,
|
||||
self.use_layerwise)
|
||||
|
||||
############################################################
|
||||
# Scheduler Side Methods
|
||||
############################################################
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens)
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.update_state_after_alloc(
|
||||
request, blocks, num_external_tokens)
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.build_connector_meta(scheduler_output)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.request_finished(request, block_ids)
|
||||
|
||||
############################################################
|
||||
# Worker Side Methods
|
||||
############################################################
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.start_load_kv(self._get_connector_metadata())
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
if not self.use_layerwise:
|
||||
return
|
||||
self.connector_worker.wait_for_layer_load()
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||
if not self.use_layerwise:
|
||||
return
|
||||
|
||||
if self.kv_role == "kv_consumer":
|
||||
# Don't do save if the role is kv_consumer
|
||||
return
|
||||
self.connector_worker.save_kv_layer(self._get_connector_metadata())
|
||||
|
||||
def wait_for_save(self):
|
||||
if self.kv_role == "kv_consumer":
|
||||
# Don't do save if the role is kv_consumer
|
||||
return
|
||||
|
||||
if self.use_layerwise:
|
||||
return
|
||||
|
||||
self.connector_worker.wait_for_save(self._get_connector_metadata())
|
||||
|
||||
def get_finished(self,
|
||||
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
||||
"""Get the finished recving and sending requests."""
|
||||
assert self.connector_worker is not None
|
||||
meta = self._get_connector_metadata()
|
||||
done_sending, done_recving = self.connector_worker.get_finished()
|
||||
sended_and_finished: set[str] = set()
|
||||
for item in list(self.sended_but_unfinished_reqs):
|
||||
if item not in meta.unfinished_request_ids:
|
||||
sended_and_finished.add(item)
|
||||
self.sended_but_unfinished_reqs.remove(item)
|
||||
for item in done_sending:
|
||||
if item in meta.unfinished_request_ids:
|
||||
self.sended_but_unfinished_reqs.add(item)
|
||||
else:
|
||||
sended_and_finished.add(item)
|
||||
|
||||
return sended_and_finished, done_recving
|
||||
|
||||
|
||||
class LookupKeyServer:
    """ZMQ REP server (run on worker rank 0) answering lookup queries.

    A scheduler sends a multipart message whose first frame is the token
    count (big-endian int) and whose remaining frames are msgpack-encoded
    block hashes; the reply is a 4-byte big-endian count of tokens already
    present in the pool, as computed by ``pool_worker.lookup_scheduler``.
    """

    def __init__(
        self,
        pool_worker: KVPoolWorker,
        vllm_config: "VllmConfig",
        use_layerwise: bool,
    ):
        self.decoder = MsgpackDecoder()
        self.decoder_tensor = MsgpackDecoder(torch.Tensor)
        self.ctx = zmq.Context()  # type: ignore[attr-defined]
        socket_path = get_zmq_rpc_path_lookup(vllm_config)
        self.socket = make_zmq_socket(
            self.ctx,
            socket_path,
            zmq.REP,  # type: ignore[attr-defined]
            bind=True,
        )

        self.pool_worker = pool_worker
        self.running = True
        self.use_layerwise = use_layerwise

        def process_request():
            # Serve until close() clears `running`; closing the socket
            # unblocks the pending recv with a ZMQError, ending the loop.
            while self.running:
                try:
                    all_frames = self.socket.recv_multipart(copy=False)
                except zmq.ZMQError:  # type: ignore[attr-defined]
                    break
                # Frame 0: token count; remaining frames: encoded hashes.
                token_len = int.from_bytes(all_frames[0], byteorder="big")
                hash_frames = all_frames[1:]
                hashes_str = self.decoder.decode(hash_frames)
                result = self.pool_worker.lookup_scheduler(
                    token_len, hashes_str, self.use_layerwise)
                response = result.to_bytes(4, "big")
                self.socket.send(response)

        self.thread = threading.Thread(target=process_request, daemon=True)
        self.thread.start()

    def close(self):
        # Fix for the original TODO ("close the thread!"): signal the loop
        # to stop, close the socket so a blocked recv raises, then give the
        # daemon thread a moment to exit cleanly.
        self.running = False
        self.socket.close(linger=0)
        self.thread.join(timeout=1.0)
|
||||
1
vllm_ascend/distributed/kvpool/backend/__init__.py
Normal file
1
vllm_ascend/distributed/kvpool/backend/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
29
vllm_ascend/distributed/kvpool/backend/backend.py
Normal file
29
vllm_ascend/distributed/kvpool/backend/backend.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
|
||||
|
||||
class Backend(ABC):
    """Abstract interface to a distributed KV storage engine.

    Concrete backends (Mooncake, Memcache, ...) translate these calls into
    their engine's native batch API. Keys are opaque strings; data is
    addressed by raw device pointers plus byte sizes, one address/size list
    per key.
    """

    def __init__(self, parallel_config: ParallelConfig):
        # Default no-op; subclasses connect to their store here.
        pass

    def set_device(self):
        # Optional hook: bind the backend to this worker's device.
        pass

    def register_buffer(self, ptrs: list[int], lengths: list[int]):
        # Optional hook: register device memory regions with the engine
        # (e.g. for zero-copy transfers).
        pass

    @abstractmethod
    def exists(self, keys: list[str]) -> list[int]:
        # Return one flag per key indicating presence in the store.
        pass

    @abstractmethod
    def put(self, keys: list[str], addrs: list[list[int]],
            sizes: list[list[int]]):
        # Write the buffers at addrs/sizes into the store under keys.
        pass

    @abstractmethod
    def get(self, keys: list[str], addrs: list[list[int]],
            sizes: list[list[int]]):
        # Read keys from the store into the buffers at addrs/sizes.
        pass
|
||||
74
vllm_ascend/distributed/kvpool/backend/memcache_backend.py
Normal file
74
vllm_ascend/distributed/kvpool/backend/memcache_backend.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# Standard
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.utils import logger
|
||||
|
||||
from vllm_ascend.distributed.kvpool.backend.backend import Backend
|
||||
|
||||
|
||||
class MmcDirect(Enum):
    # Copy directions passed to memcache batch transfer calls.
    # Names suggest Local/Global/Host endpoints (e.g. L2G = local-to-global)
    # — TODO confirm against the memfabric_hybrid documentation.
    COPY_L2G = 0
    COPY_G2L = 1
    COPY_G2H = 2
    COPY_H2G = 3
|
||||
|
||||
|
||||
class MemcacheBackend(Backend):
    """``Backend`` implementation on top of memfabric's DistributedObjectStore.

    Transfers use per-layer batch APIs with an explicit copy direction
    (``MmcDirect``). get/put log failures instead of raising, matching the
    best-effort behavior of the other backends.
    """

    def __init__(self, parallel_config: ParallelConfig):
        try:
            from memcache import DistributedObjectStore  # type: ignore
        except ImportError as e:
            raise ImportError(
                "Please install memcache by following the instructions at "
                "https://gitee.com/ascend/memfabric_hybrid "  # noqa: E501
                "to run vLLM with MemcacheConnector.") from e
        try:
            self.rank = parallel_config.rank
            self.store = DistributedObjectStore()
            res = self.store.init(self.rank)
            assert res == 0
        except ValueError as e:
            logger.error("Configuration loading failed: %s", e)
            raise
        except Exception as exc:
            logger.error(
                "An error occurred while loading the configuration: %s", exc)
            raise

    def set_device(self):
        # Bind this process to its NPU before issuing transfers.
        device = torch.device(f"npu:{self.rank}")
        torch.npu.set_device(device)

    def register_buffer(self, ptrs: list[int], sizes: list[int]):
        """Register device memory regions with the store; raises on failure."""
        for ptr, size in zip(ptrs, sizes):
            ret_value = self.store.register_buffer(ptr, size)
            if ret_value != 0:
                raise RuntimeError("Memcache memory registration failed.")

    def exists(self, keys: list[str]) -> list[int]:
        """Return one presence flag per key."""
        return self.store.batch_is_exist(keys)

    def get(self, key: list[str], addr: list[list[int]],
            size: list[list[int]]):
        """Read keys from the store into device buffers (global -> local)."""
        try:
            res = self.store.batch_get_into_layers(key, addr, size,
                                                   MmcDirect.COPY_G2L.value)
            for value in res:
                if value != 0:
                    logger.error(f"Failed to get key {key},res:{res}")
        except Exception as e:
            logger.error(f"Failed to get key {key}. {e}")

    def put(self, key: list[str], addr: list[list[int]],
            size: list[list[int]]):
        """Write device buffers to the store (local -> global)."""
        try:
            res = self.store.batch_put_from_layers(key, addr, size,
                                                   MmcDirect.COPY_L2G.value)
            for value in res:
                if value != 0:
                    # Fix: original logged "Failed to get key" here (copy/paste
                    # from get()); this is the put path.
                    logger.error(f"Failed to put key {key},res:{res}")
        except Exception as e:
            logger.error(f"Failed to put key {key},error:{e}")
|
||||
188
vllm_ascend/distributed/kvpool/backend/mooncake_backend.py
Normal file
188
vllm_ascend/distributed/kvpool/backend/mooncake_backend.py
Normal file
@@ -0,0 +1,188 @@
|
||||
# Standard
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
|
||||
# Third Party
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.utils import logger
|
||||
from vllm.utils.network_utils import get_ip
|
||||
|
||||
from vllm_ascend.distributed.kvpool.backend.backend import Backend
|
||||
from vllm_ascend.distributed.mooncake_transfer_engine import global_te
|
||||
|
||||
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
|
||||
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
|
||||
|
||||
|
||||
class MooncakeBackend(Backend):
    """``Backend`` implementation on top of MooncakeDistributedStore.

    Configuration comes from the JSON file pointed to by the
    ``MOONCAKE_CONFIG_PATH`` environment variable (see MooncakeStoreConfig).
    """

    def __init__(self, parallel_config: ParallelConfig):
        try:
            from mooncake.store import MooncakeDistributedStore  # type: ignore
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
                "to run vLLM with MooncakeConnector.") from e
        self.config = MooncakeStoreConfig.load_from_env()
        self.store = MooncakeDistributedStore()
        # NOTE(review): setup is only performed for the "ascend" protocol,
        # reusing the process-wide transfer engine; other protocols appear
        # to leave the store un-setup — confirm this is intended.
        if self.config.protocol == "ascend":
            local_hostname = get_ip()
            transfer_engine = global_te.get_transfer_engine(local_hostname,
                                                           device_name=None)
            # Local segment id: "<ip>:<rpc_port>" of this worker's engine.
            self.local_seg = local_hostname + ":" + str(
                transfer_engine.get_rpc_port())
            ret = self.store.setup(self.local_seg, self.config.metadata_server,
                                   self.config.global_segment_size,
                                   self.config.local_buffer_size,
                                   self.config.protocol,
                                   self.config.device_name,
                                   self.config.master_server_address,
                                   transfer_engine.get_engine())
            if ret != 0:
                msg = "Initialize mooncake failed."
                logger.error(msg)
                raise RuntimeError(msg)

    def register_buffer(self, ptrs: list[int], lengths: list[int]):
        # Registration is delegated to the shared transfer engine, not the
        # store itself.
        global_te.register_buffer(ptrs, lengths)

    def exists(self, keys: list[str]) -> list[int]:
        # One presence flag per key.
        return self.store.batch_is_exist(keys)

    def put(self, keys: list[str], addrs: list[list[int]],
            sizes: list[list[int]]):
        # Best-effort: failures are logged, not raised.
        try:
            res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes)
            for value in res:
                if value < 0:
                    logger.error(f"Failed to put key {keys},res:{res}")
        except Exception as e:
            logger.error(f"Failed to put key {keys},error:{e}")

    def get(self, keys: list[str], addrs: list[list[int]],
            sizes: list[list[int]]):
        # Best-effort: failures are logged, not raised.
        try:
            res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes)
            for value in res:
                if value < 0:
                    logger.error(f"Failed to get key {keys}, res:{res}")
        except Exception as e:
            logger.error(f"Failed to get key {keys}, error:{e}")
|
||||
|
||||
|
||||
@dataclass
class MooncakeStoreConfig:
    """Mooncake store settings, loaded from a JSON config file."""

    # Hostname of this node as written in the config file.
    local_hostname: str
    # Address of the mooncake metadata server.
    metadata_server: str
    # Size of the global segment; accepts an int (bytes) or a string with
    # units ("3gb", "512mb", ...) — parsed by _parse_global_segment_size.
    global_segment_size: Union[int, str]
    # Size of the local staging buffer in bytes.
    local_buffer_size: int
    # Transport protocol, e.g. "tcp" or "ascend".
    protocol: str
    # Device name passed to the store setup ("" = default).
    device_name: str
    # Address of the mooncake master server.
    master_server_address: str
    # Whether to use the Ascend direct transfer path.
    use_ascend_direct: bool

    @staticmethod
    def from_file(file_path: str) -> "MooncakeStoreConfig":
        """Load the config from a JSON file, applying defaults for
        optional fields."""
        with open(file_path) as file:
            config = json.load(file)
        return MooncakeStoreConfig(
            local_hostname=config.get("local_hostname"),
            metadata_server=config.get("metadata_server"),
            global_segment_size=_parse_global_segment_size(
                config.get("global_segment_size",
                           DEFAULT_GLOBAL_SEGMENT_SIZE)),
            local_buffer_size=(config.get("local_buffer_size",
                                          DEFAULT_LOCAL_BUFFER_SIZE)),
            protocol=config.get("protocol", "tcp"),
            device_name=config.get("device_name", ""),
            master_server_address=config.get("master_server_address"),
            use_ascend_direct=config.get("use_ascend_direct", False))

    @staticmethod
    def load_from_env() -> "MooncakeStoreConfig":
        """Load the config from the file named by MOONCAKE_CONFIG_PATH.

        Raises:
            ValueError: if the environment variable is not set.
        """
        config_path = os.getenv("MOONCAKE_CONFIG_PATH")
        if not config_path:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
        return MooncakeStoreConfig.from_file(config_path)
|
||||
|
||||
|
||||
def _parse_global_segment_size(value) -> int:
    """Parse a storage size into bytes.

    Accepts an int (returned as-is), any int()-convertible object, or a
    string with an optional unit suffix: GB, MB, KB, B (case-insensitive,
    1024-based). A bare number is treated as bytes.

    Args:
        value: Input value (int, str, or other convertible types)

    Returns:
        int: Size in bytes

    Raises:
        ValueError: For invalid format, missing number, or negative values
        TypeError: For unsupported input types
    """
    # Ints pass through untouched; other non-strings must convert cleanly.
    if isinstance(value, int):
        return value
    if not isinstance(value, str):
        try:
            return int(value)
        except (TypeError, ValueError) as e:
            raise TypeError(
                f"Unsupported type for global_segment_size: {type(value)}"
            ) from e

    text = value.strip().lower()
    if not text:
        raise ValueError("global segment size cannot be empty.")

    # "<number><optional unit>", e.g. "3.5gb", "512 mb", "1024".
    parsed = re.match(r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$', text)
    if parsed is None:
        raise ValueError(f"Invalid format: '{value}'")

    unit_scale = {
        'gb': 1024**3,  # 1 GB = 1024^3 bytes
        'mb': 1024**2,  # 1 MB = 1024^2 bytes
        'kb': 1024,  # 1 KB = 1024 bytes
        'b': 1  # 1 B = 1 byte
    }
    scale = unit_scale[parsed.group(2) or 'b']
    return _convert_to_bytes(parsed.group(1), scale, value)
|
||||
|
||||
|
||||
def _convert_to_bytes(number_str: str, multiplier: int,
|
||||
original_input: str) -> int:
|
||||
"""
|
||||
Convert numeric string to byte count
|
||||
|
||||
Args:
|
||||
number_str: Numeric portion of input
|
||||
multiplier: Unit conversion factor
|
||||
original_input: Original input string (for error messages)
|
||||
|
||||
Returns:
|
||||
int: Byte count
|
||||
|
||||
Raises:
|
||||
ValueError: For invalid numbers or negative results
|
||||
"""
|
||||
try:
|
||||
numeric_value = float(number_str)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid numeric value '{number_str}' in: '{original_input}'")
|
||||
# Calculate byte count
|
||||
try:
|
||||
byte_count = int(numeric_value * multiplier)
|
||||
except OverflowError:
|
||||
raise ValueError(f"Storage size too large: '{original_input}'")
|
||||
return byte_count
|
||||
364
vllm_ascend/distributed/kvpool/config_data.py
Normal file
364
vllm_ascend/distributed/kvpool/config_data.py
Normal file
@@ -0,0 +1,364 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
|
||||
KVConnectorMetadata
|
||||
from vllm.utils import logger
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.core.sched.output import NewRequestData
|
||||
|
||||
|
||||
#Parameters related to the key
|
||||
@dataclass
class KeyMetadata:
    """Identifies which model/worker a KV-pool key belongs to."""

    # Name of the LLM model whose KV blocks are stored.
    model_name: str
    # Worker id when running under a distributed setting (head id for MLA
    # or tensor-parallel rank — see the field name).
    head_or_tp_rank: int
|
||||
|
||||
|
||||
@dataclass(order=True)
class PoolKey:
    """Key addressing one chunk of KV data in the pool.

    Hashing is defined explicitly over (model_name, head_or_tp_rank,
    chunk_hash) so keys are usable in sets/dicts; to_string() produces the
    storage-engine key string.
    """

    key_metadata: KeyMetadata
    chunk_hash: str

    def __hash__(self):
        # Explicit hash: the dataclass-generated __eq__ would otherwise set
        # __hash__ to None.
        return hash((
            self.key_metadata.model_name,
            self.key_metadata.head_or_tp_rank,
            self.chunk_hash,
        ))

    def to_string(self):
        """Serialize to the flat string used as the storage-engine key."""
        return (
            f"{self.key_metadata.model_name}"
            f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}"
        )

    def split_layers(self, num_layers: int) -> List["LayerPoolKey"]:
        """Split the key into multiple keys for each layer"""
        keys = []
        for layer_id in range(num_layers):
            keys.append(
                LayerPoolKey(
                    self.key_metadata,
                    self.chunk_hash,
                    layer_id,
                ))
        return keys
|
||||
|
||||
|
||||
@dataclass(order=True)
class LayerPoolKey(PoolKey):
    """A key for the layer cache engine"""

    # Index of the model layer this key addresses.
    layer_id: int

    def __hash__(self):
        # Include layer_id so per-layer keys of one chunk don't collide.
        return hash((
            self.key_metadata.model_name,
            self.key_metadata.head_or_tp_rank,
            self.chunk_hash,
            self.layer_id,
        ))

    def to_string(self):
        """Serialize to the storage-engine key string, layer-qualified."""
        return (
            f"{self.key_metadata.model_name}"
            f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}@{self.layer_id}"
        )
|
||||
|
||||
|
||||
class ChunkedTokenDatabase():
    """Maps token chunks (one vLLM block each) to pool keys and to the
    device addresses/sizes of their KV cache storage.

    Address math relies on set_kv_caches_base_addr()/set_block_len() having
    been called with the flattened per-layer cache base addresses and the
    per-block byte lengths.
    """

    def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool):
        self.metadata = metadata
        self.block_size = block_size
        self.use_mla = use_mla
        # Base device address of each KV cache tensor (k/v interleaved per
        # layer — see prepare_value_layer's layer_id*2 indexing).
        self.kv_caches_base_addr: list[int] = []
        # Per-block byte length(s); two entries (k, v) when use_mla.
        self.block_len: list[int] = []

    def _make_key_by_hash(self,
                          chunk_hash: str,
                          layer_id: Optional[int] = None):
        # NOTE(review): layer_id is accepted but ignored — a plain PoolKey
        # is always built; confirm layerwise callers expect this.
        assert self.metadata is not None
        return PoolKey(
            self.metadata,
            chunk_hash,
        )

    def set_kv_caches_base_addr(self, kv_caches_base_addr: list[int]):
        self.kv_caches_base_addr = kv_caches_base_addr

    def set_block_len(self, block_len: list[int]):
        self.block_len = block_len

    def prepare_value(self, start: int, end: int, block_ids: list[int]):
        """Return (addresses, sizes, block_id) covering tokens [start, end)
        across all cache tensors for the block containing `start`."""
        addr_list = []
        size_list = []
        block_id = block_ids[start // self.block_size]
        for index, base_addr in enumerate(self.kv_caches_base_addr):
            # MLA caches alternate two block lengths; otherwise uniform.
            block_len = (self.block_len[index % 2]
                         if self.use_mla else self.block_len[0])

            addr = base_addr + block_id * block_len
            # Bytes for (end - start) tokens out of a full block.
            length = int(block_len / self.block_size * (end - start))
            addr_list.append(addr)
            size_list.append(length)
        return addr_list, size_list, block_id

    def prepare_value_layer(self, start: int, end: int, block_ids: list[int],
                            layer_id: int):
        """Return (addresses, sizes) for tokens [start, end) of one layer's
        k and v caches (base addresses are k/v interleaved per layer)."""
        block_id = block_ids[start // self.block_size]
        if self.use_mla:
            addr_k = self.kv_caches_base_addr[layer_id *
                                              2] + block_id * self.block_len[0]
            addr_v = self.kv_caches_base_addr[layer_id * 2 +
                                              1] + block_id * self.block_len[1]
            length_k = int(self.block_len[0] / self.block_size * (end - start))
            length_v = int(self.block_len[1] / self.block_size * (end - start))
            size_list = [length_k, length_v]
        else:
            addr_k = self.kv_caches_base_addr[layer_id *
                                              2] + block_id * self.block_len[0]
            addr_v = self.kv_caches_base_addr[layer_id * 2 +
                                              1] + block_id * self.block_len[0]
            length = int(self.block_len[0] / self.block_size * (end - start))
            size_list = [length, length]
        addr_list = [addr_k, addr_v]
        return addr_list, size_list

    def process_tokens(
        self,
        token_len: int,
        block_hashes: Union[list[BlockHash], list[str]],
        mask_num: int = 0,
    ) -> Iterable[Tuple[int, int, PoolKey]]:
        """Yield (start, end, key) for each block-sized token chunk.

        :param int token_len: total number of tokens; the last chunk is
            truncated to this length and iteration stops there.

        :param block_hashes: one hash per block, either raw BlockHash
            objects (converted via .hex()) or pre-stringified hashes.

        :param int mask_num: number of leading tokens to skip; chunks that
            start before this offset are not yielded.

        :returns: an iterable of (start_idx, end_idx, PoolKey) tuples, one
            per yielded chunk.
        """
        if not block_hashes:
            return
        # Normalize BlockHash objects to hex strings once up front.
        if not isinstance(block_hashes[0], str):
            block_hashes = [
                h.hex()  # type: ignore[union-attr]
                for h in block_hashes
            ]
        start_idx = 0
        for chunk_id, hash_val in enumerate(block_hashes):
            start_idx = chunk_id * self.block_size
            if start_idx >= token_len:
                break
            end_idx = min(start_idx + self.block_size, token_len)
            if start_idx < mask_num:
                # Chunk is covered by the mask prefix; skip it.
                continue
            else:
                yield start_idx, end_idx, self._make_key_by_hash(hash_val)
|
||||
|
||||
|
||||
#Parameters related to the connector metadata
|
||||
@dataclass
class LoadSpec:
    """Describes how much of a request's KV can be loaded from the pool."""

    # Number of tokens cached in vLLM
    vllm_cached_tokens: int
    # Number of tokens that are cached in kvpool
    kvpool_cached_tokens: int
    # Whether the scheduler allows us to load the tokens
    can_load: bool
|
||||
|
||||
|
||||
@dataclass
class RequestTracker:
    """Scheduler-side record of how far a request has progressed."""

    # Request id
    req_id: str

    # Number of token ids that have been scheduled so far
    token_len: int

    # The block ids that have been allocated so far
    # NOTE: allocated blocks could be more than the number of tokens
    # FIXME: need to check whether the block ids will be changed after
    # preemption
    allocated_block_ids: list[int]

    # The number of tokens that have been saved
    num_saved_tokens: int = 0

    @staticmethod
    def from_new_request(
        new_request: "NewRequestData",
        num_tokens_to_compute: int,
    ) -> "RequestTracker":
        """Create the request tracker from a new request.

        Args:
            new_request (NewRequestData): the new request data.
            num_tokens_to_compute (int): the number of tokens that will
                be 'computed', including the `num_computed_tokens` (vLLM's
                local cache hit) and new tokens that will be scheduled.

        """
        unfolded_block_ids = []

        # block_ids may be a flat list or a per-kv-cache-group list of
        # lists; only the first group is tracked here.
        if not isinstance(new_request.block_ids[0], list):
            unfolded_block_ids = new_request.block_ids.copy()
        else:
            unfolded_block_ids = new_request.block_ids[0].copy()

        return RequestTracker(
            req_id=new_request.req_id,
            token_len=num_tokens_to_compute,
            allocated_block_ids=unfolded_block_ids,
            num_saved_tokens=0,
        )

    def update(
        self,
        new_token_ids: list[int],
        new_block_ids: Union[tuple[list[int], ...], list[int]],
    ) -> None:
        """Update the request tracker when a running request is
        scheduled again.

        Extends the token count and appends any newly allocated block ids,
        normalizing the tuple-of-lists form to its first group.
        """

        self.token_len = self.token_len + len(new_token_ids)

        if len(new_block_ids) == 0:
            new_block_ids = []
        elif isinstance(new_block_ids, tuple):
            # Per-group tuple: track only the first group (see
            # from_new_request).
            new_block_ids = new_block_ids[0]
        elif isinstance(new_block_ids, list):
            pass
        else:
            raise ValueError(
                f"Unsupported new_block_ids type {type(new_block_ids)}")
        self.allocated_block_ids.extend(new_block_ids)
|
||||
|
||||
|
||||
@dataclass
class ReqMeta:
    """Per-request load/save instructions shipped to the worker."""

    # Request id
    req_id: str
    # Number of request tokens to process, truncated to a chunk boundary
    # when partial chunks are discarded.
    token_len_chunk: int
    # Block ids allocated for this request so far.
    block_ids: list[int]
    # Block hashes used to build the pool keys.
    block_hashes: list[BlockHash]
    # Whether the worker should save this request's KV.
    can_save: Optional[bool] = None
    # load_spec
    load_spec: Optional[LoadSpec] = None
    # Whether this is the request's final scheduled chunk.
    is_last_chunk: Optional[bool] = None

    @staticmethod
    def from_request_tracker(
        tracker: RequestTracker,
        block_size: int,
        load_spec: Optional[LoadSpec] = None,
        skip_save: Optional[bool] = False,
        block_hashes: Optional[list[BlockHash]] = None,
        is_last_chunk: Optional[bool] = None,
        discard_partial_chunks: bool = True,
    ) -> Optional["ReqMeta"]:
        """Create the request metadata from a request tracker.

        Args:
            tracker (RequestTracker): the request tracker.
            block_size (int): the block size in vLLM.
            load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
            skip_save (bool): whether to skip the save operation.
            block_hashes: per-block hashes for the request (defaults to
                an empty list).
            is_last_chunk: whether this is the final scheduled chunk.
            discard_partial_chunks (bool): whether to discard partial chunks.

        Returns:
            the request metadata if we need to perform load/save
            operations, None otherwise.
        """
        # Fix: the original used a mutable default argument ([]) for
        # block_hashes; normalize a None default instead.
        block_hashes = block_hashes if block_hashes is not None else []
        input_token_len = tracker.token_len

        # For save operation: do not save if the following condition is met
        # 1. has already been saved before (num_saved_tokens > 0)
        # 2. number of unsaved tokens is not reached the chunk boundary
        chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) *
                          block_size if discard_partial_chunks else 0)
        # Calculate number of tokens to save based on discard_partial_chunks
        # setting
        num_tokens_to_save = ((input_token_len // block_size * block_size)
                              if discard_partial_chunks else input_token_len)

        skip_save = skip_save or num_tokens_to_save < chunk_boundary
        if skip_save and load_spec is None:
            return None

        # If we need to save, update the number of saved tokens
        if not skip_save:
            tracker.num_saved_tokens = num_tokens_to_save

        # For load operation: check whether the request is scheduled to load
        if load_spec is not None and load_spec.can_load:
            logger.debug(
                "Scheduled to load %d tokens for request %s",
                load_spec.kvpool_cached_tokens,
                tracker.req_id,
            )
        else:
            # Do not load if not in `can_load` state
            load_spec = None
        logger.debug(
            f"request:{tracker.req_id}, meta save spec:{not skip_save}, meta load spec:{load_spec}"
        )
        return ReqMeta(
            req_id=tracker.req_id,
            token_len_chunk=num_tokens_to_save,
            block_ids=tracker.allocated_block_ids,
            can_save=not skip_save,
            load_spec=load_spec,
            block_hashes=block_hashes,
            is_last_chunk=is_last_chunk,
        )
|
||||
|
||||
|
||||
class AscendConnectorMetadata(KVConnectorMetadata):
    """Connector metadata shipped from the scheduler to the workers for
    one scheduling step."""

    def __init__(self, unfinished_request_ids):
        # Per-request load/save instructions, filled via add_request().
        self.requests = []
        # Ids of requests the scheduler still considers in flight; workers
        # use this to defer "finished sending" reporting.
        self.unfinished_request_ids = unfinished_request_ids

    def add_request(self, req_meta: ReqMeta) -> None:
        """Queue *req_meta* for processing on the worker side.

        Args:
            req_meta (ReqMeta): the request metadata.
        """
        self.requests.append(req_meta)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LasyerMultiBlockReqMeta:
|
||||
req_id: str
|
||||
keys: List[LayerPoolKey]
|
||||
starts: List[int]
|
||||
ends: list[int]
|
||||
block_ids: list[int]
|
||||
layer_id: int
|
||||
is_last_chunk: bool = True
|
||||
246
vllm_ascend/distributed/kvpool/kv_transfer.py
Normal file
246
vllm_ascend/distributed/kvpool/kv_transfer.py
Normal file
@@ -0,0 +1,246 @@
|
||||
import queue
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from vllm.utils import logger
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
|
||||
from vllm_ascend.distributed.kvpool.backend.backend import Backend
|
||||
|
||||
# isort: off
|
||||
from vllm_ascend.distributed.kvpool.config_data import (ChunkedTokenDatabase,
|
||||
LasyerMultiBlockReqMeta
|
||||
)
|
||||
# isort: on
|
||||
|
||||
|
||||
class KVTransferThread(threading.Thread):
|
||||
|
||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
||||
tp_rank: int, ready_event: threading.Event, name: str):
|
||||
super().__init__(daemon=True, name=name)
|
||||
self.m_store = m_store
|
||||
self.ready_event = ready_event
|
||||
self.tp_rank = tp_rank
|
||||
self.token_database = token_database
|
||||
self.done_task_lock = threading.Lock()
|
||||
self.request_queue: queue.Queue[Any] = queue.Queue()
|
||||
# TODO(jianzs): make this configurable
|
||||
self.executor = ThreadPoolExecutor(max_workers=32)
|
||||
self.finished_requests: set[str] = set()
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
req_id: str,
|
||||
token_len: int,
|
||||
block_ids: list[int],
|
||||
block_hashes: list[BlockHash],
|
||||
mask_num: int = 0,
|
||||
is_last_chunk: Optional[bool] = None,
|
||||
) -> torch.Tensor:
|
||||
req = ({
|
||||
"req_id": req_id,
|
||||
"token_len": token_len,
|
||||
"block_ids": block_ids,
|
||||
"block_hashes": block_hashes,
|
||||
"mask_num": mask_num,
|
||||
"is_last_chunk": is_last_chunk,
|
||||
})
|
||||
self.request_queue.put(req)
|
||||
|
||||
def get_and_clear_finished_requests(self) -> set[str]:
|
||||
"""
|
||||
Get and clear the requests that have been completed.
|
||||
Returns:
|
||||
A set of request IDs that have been completed.
|
||||
"""
|
||||
with self.done_task_lock:
|
||||
finished_requests = self.finished_requests.copy()
|
||||
self.finished_requests.clear()
|
||||
return finished_requests
|
||||
|
||||
def set_finished_request(self, req_id):
|
||||
with self.done_task_lock:
|
||||
self.finished_requests.add(req_id)
|
||||
|
||||
def run(self):
|
||||
"""Run the thread to handle KV cache transfer requests."""
|
||||
self.m_store.set_device()
|
||||
self.ready_event.set()
|
||||
while True:
|
||||
try:
|
||||
request_data = self.request_queue.get()
|
||||
if request_data is None:
|
||||
logger.warning("Received a None request!")
|
||||
self.request_queue.task_done()
|
||||
continue
|
||||
self._handle_request(request_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in KVCacheTransferThread: {e}")
|
||||
|
||||
def _handle_request(self, req_meta: dict[str, Any]):
|
||||
pass
|
||||
|
||||
|
||||
class KVCacheStoreSendingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
||||
tp_rank: int, put_step: int, ready_event: threading.Event):
|
||||
super().__init__(m_store,
|
||||
token_database,
|
||||
tp_rank,
|
||||
ready_event,
|
||||
name="KVCacheSendingThread")
|
||||
self.put_step = put_step
|
||||
|
||||
def _handle_request(self, req_meta: dict[str, Any]):
|
||||
token_len = req_meta["token_len"]
|
||||
mask_num = req_meta["mask_num"]
|
||||
block_ids = req_meta["block_ids"]
|
||||
block_hashes = req_meta["block_hashes"]
|
||||
req_id = req_meta["req_id"]
|
||||
is_last_chunk = req_meta["is_last_chunk"]
|
||||
addr_list = []
|
||||
size_list = []
|
||||
key_list = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
token_len, block_hashes, mask_num):
|
||||
addr, size, _ = self.token_database.prepare_value(
|
||||
start, end, block_ids)
|
||||
key_list.append(key.to_string())
|
||||
addr_list.append(addr)
|
||||
size_list.append(size)
|
||||
key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
|
||||
addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step]
|
||||
size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step]
|
||||
if key_list_tp:
|
||||
torch.npu.current_stream().synchronize()
|
||||
self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
|
||||
if is_last_chunk:
|
||||
self.set_finished_request(req_id)
|
||||
self.request_queue.task_done()
|
||||
|
||||
|
||||
class KVCacheStoreRecvingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
||||
tp_rank: int, ready_event: threading.Event):
|
||||
super().__init__(m_store,
|
||||
token_database,
|
||||
tp_rank,
|
||||
ready_event,
|
||||
name="KVCacheStoreRecvingThread")
|
||||
|
||||
def _handle_request(self, req_meta: dict[str, Any]):
|
||||
token_len = req_meta["token_len"]
|
||||
mask_num = req_meta["mask_num"]
|
||||
block_ids = req_meta["block_ids"]
|
||||
req_id = req_meta["req_id"]
|
||||
block_hashes = req_meta["block_hashes"]
|
||||
addr_list = []
|
||||
size_list = []
|
||||
key_list = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
token_len, block_hashes, mask_num):
|
||||
addr, size, _ = self.token_database.prepare_value(
|
||||
start, end, block_ids)
|
||||
key_list.append(key.to_string())
|
||||
addr_list.append(addr)
|
||||
size_list.append(size)
|
||||
key_list_c = key_list[self.tp_rank %
|
||||
len(key_list):] + key_list[:self.tp_rank %
|
||||
len(key_list)]
|
||||
addr_list_c = addr_list[self.tp_rank %
|
||||
len(addr_list):] + addr_list[:self.tp_rank %
|
||||
len(addr_list)]
|
||||
size_list_c = size_list[self.tp_rank %
|
||||
len(size_list):] + size_list[:self.tp_rank %
|
||||
len(size_list)]
|
||||
self.m_store.get(key_list_c, addr_list_c, size_list_c)
|
||||
self.set_finished_request(req_id)
|
||||
self.request_queue.task_done()
|
||||
|
||||
|
||||
class KVCacheStoreLayerSendingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
||||
tp_rank: int, put_step: int, ready_event: threading.Event,
|
||||
num_layers: int):
|
||||
super().__init__(m_store,
|
||||
token_database,
|
||||
tp_rank,
|
||||
ready_event,
|
||||
name="KVCacheStoreLayerSendingThread")
|
||||
self.final_layer_id = num_layers - 1
|
||||
self.put_step = put_step
|
||||
|
||||
def add_request( # type: ignore[override]
|
||||
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
|
||||
self.request_queue.put(req_meta)
|
||||
|
||||
def _handle_request( # type: ignore[override]
|
||||
self, req_meta: LasyerMultiBlockReqMeta):
|
||||
addr_list = []
|
||||
size_list = []
|
||||
key_list = []
|
||||
for index, key in enumerate(req_meta.keys):
|
||||
addr, size = self.token_database.prepare_value_layer(
|
||||
req_meta.starts[index], req_meta.ends[index],
|
||||
req_meta.block_ids, req_meta.layer_id)
|
||||
key_list.append(key.to_string())
|
||||
addr_list.append(addr)
|
||||
size_list.append(size)
|
||||
key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
|
||||
addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step]
|
||||
size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step]
|
||||
if key_list_tp:
|
||||
torch.npu.current_stream().synchronize()
|
||||
self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
|
||||
if req_meta.layer_id == self.final_layer_id and req_meta.is_last_chunk:
|
||||
self.set_finished_request(req_meta.req_id)
|
||||
self.request_queue.task_done()
|
||||
|
||||
|
||||
class KVCacheStoreLayerRecvingThread(KVTransferThread):
|
||||
|
||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
||||
tp_rank: int, ready_event: threading.Event,
|
||||
get_event: threading.Event):
|
||||
super().__init__(m_store,
|
||||
token_database,
|
||||
tp_rank,
|
||||
ready_event,
|
||||
name="KVCacheStoreLayerRecvingThread")
|
||||
self.get_event = get_event
|
||||
|
||||
def add_request( # type: ignore[override]
|
||||
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
|
||||
self.request_queue.put(req_meta)
|
||||
|
||||
def _handle_request( # type: ignore[override]
|
||||
self, req_meta: LasyerMultiBlockReqMeta):
|
||||
addr_list = []
|
||||
size_list = []
|
||||
key_list = []
|
||||
for index, key in enumerate(req_meta.keys):
|
||||
addr, size = self.token_database.prepare_value_layer(
|
||||
req_meta.starts[index], req_meta.ends[index],
|
||||
req_meta.block_ids, req_meta.layer_id)
|
||||
key_list.append(key.to_string())
|
||||
addr_list.append(addr)
|
||||
size_list.append(size)
|
||||
key_list_c = key_list[self.tp_rank %
|
||||
len(key_list):] + key_list[:self.tp_rank %
|
||||
len(key_list)]
|
||||
addr_list_c = addr_list[self.tp_rank %
|
||||
len(addr_list):] + addr_list[:self.tp_rank %
|
||||
len(addr_list)]
|
||||
size_list_c = size_list[self.tp_rank %
|
||||
len(size_list):] + size_list[:self.tp_rank %
|
||||
len(size_list)]
|
||||
self.m_store.get(key_list_c, addr_list_c, size_list_c)
|
||||
|
||||
self.request_queue.task_done()
|
||||
self.get_event.set()
|
||||
@@ -1,174 +1,33 @@
|
||||
import threading
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import vllm.envs as envs
|
||||
import zmq
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
|
||||
KVConnectorMetadata
|
||||
from vllm.utils import logger
|
||||
from vllm.utils.network_utils import make_zmq_socket
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
from vllm.v1.serial_utils import MsgpackEncoder
|
||||
|
||||
from vllm_ascend.distributed.mooncake.config_data import (
|
||||
LoadSpec, MooncakeConnectorMetadata, ReqMeta, RequestTracker)
|
||||
from vllm_ascend.distributed.mooncake.mooncake_engine import MooncakeEngine
|
||||
from vllm_ascend.distributed.kvpool.config_data import (
|
||||
AscendConnectorMetadata, LoadSpec, ReqMeta, RequestTracker)
|
||||
|
||||
|
||||
class MooncakeConnectorV1(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
|
||||
super().__init__(vllm_config=vllm_config, role=role)
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
|
||||
self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"use_layerwise", False)
|
||||
|
||||
self.kv_caches: dict[str, torch.Tensor] = {}
|
||||
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
|
||||
self.sended_but_unfinished_reqs: set[str] = set()
|
||||
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler = MooncakeStoreConnectorV1Scheduler(
|
||||
vllm_config, self.use_layerwise)
|
||||
else:
|
||||
self.connector_worker = MooncakeEngine(
|
||||
vllm_config,
|
||||
self.use_layerwise,
|
||||
)
|
||||
|
||||
assert self.connector_worker is not None
|
||||
if vllm_config.parallel_config.rank == 0:
|
||||
self.lookup_server = MooncakeLookupServer(
|
||||
self.connector_worker, vllm_config, self.use_layerwise)
|
||||
|
||||
############################################################
|
||||
# Scheduler Side Methods
|
||||
############################################################
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens)
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.update_state_after_alloc(
|
||||
request, blocks, num_external_tokens)
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.build_connector_meta(scheduler_output)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.request_finished(request, block_ids)
|
||||
|
||||
############################################################
|
||||
# Worker Side Methods
|
||||
############################################################
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
assert self.connector_worker is not None
|
||||
assert isinstance(self._get_connector_metadata(),
|
||||
MooncakeConnectorMetadata)
|
||||
self.connector_worker.start_load_kv(self._get_connector_metadata())
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""MooncakeStoreConnector does not do layerwise saving."""
|
||||
if not self.use_layerwise:
|
||||
return
|
||||
self.connector_worker.wait_for_layer_load()
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||
"""MooncakeStoreConnector does not save explicitly."""
|
||||
if not self.use_layerwise:
|
||||
return
|
||||
|
||||
if self.kv_role == "kv_consumer":
|
||||
# Don't do save if the role is kv_consumer
|
||||
return
|
||||
self.connector_worker.save_kv_layer(self._get_connector_metadata())
|
||||
|
||||
def wait_for_save(self):
|
||||
"""MooncakeStoreConnector does not save explicitly."""
|
||||
if self.kv_role == "kv_consumer":
|
||||
# Don't do save if the role is kv_consumer
|
||||
return
|
||||
|
||||
if self.use_layerwise:
|
||||
self.connector_worker.wait_layer_transfer_finish()
|
||||
return
|
||||
|
||||
self.connector_worker.wait_for_save(self._get_connector_metadata())
|
||||
|
||||
def get_finished(self,
|
||||
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
||||
"""Get the finished recving and sending requests."""
|
||||
assert self.connector_worker is not None
|
||||
meta = self._get_connector_metadata()
|
||||
done_sending, done_recving = self.connector_worker.get_finished()
|
||||
sended_and_finished: set[str] = set()
|
||||
for item in list(self.sended_but_unfinished_reqs):
|
||||
if item not in meta.unfinished_request_ids:
|
||||
sended_and_finished.add(item)
|
||||
self.sended_but_unfinished_reqs.remove(item)
|
||||
for item in done_sending:
|
||||
if item in meta.unfinished_request_ids:
|
||||
self.sended_but_unfinished_reqs.add(item)
|
||||
else:
|
||||
sended_and_finished.add(item)
|
||||
|
||||
return sended_and_finished, done_recving
|
||||
|
||||
|
||||
def get_zmq_rpc_path_mooncake(
|
||||
vllm_config: Optional["VllmConfig"] = None, ) -> str:
|
||||
base_url = envs.VLLM_RPC_BASE_PATH
|
||||
# Default to 0 if not configured
|
||||
rpc_port = 0
|
||||
if vllm_config is not None:
|
||||
rpc_port = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"mooncake_rpc_port", 0)
|
||||
logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port)
|
||||
return f"ipc://{base_url}/mooncake_rpc_port_{rpc_port}"
|
||||
|
||||
|
||||
class MooncakeStoreConnectorV1Scheduler:
|
||||
class KVPoolScheduler:
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", use_layerwise):
|
||||
self.client = MooncakeLookupClient(vllm_config)
|
||||
self.client = LookupKeyClient(vllm_config)
|
||||
self.use_layerwise = use_layerwise
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"consumer_is_to_load", False)
|
||||
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"load_async", False)
|
||||
# request_id -> (vllm cached tokes, mooncake cached tokens)
|
||||
# request_id -> (vllm cached tokes, kvpool cached tokens)
|
||||
self.load_specs: dict[str, LoadSpec] = {}
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
# request_id -> full_token_ids
|
||||
@@ -201,14 +60,13 @@ class MooncakeStoreConnectorV1Scheduler:
|
||||
return 0, False
|
||||
|
||||
if self._discard_partial_chunks:
|
||||
token_block_end = len(request.prompt_token_ids
|
||||
) // self._block_size * self._block_size
|
||||
token_ids = torch.tensor(
|
||||
request.prompt_token_ids[:token_block_end])
|
||||
token_len = len(request.prompt_token_ids
|
||||
) // self._block_size * self._block_size
|
||||
else:
|
||||
token_ids = torch.tensor(request.prompt_token_ids)
|
||||
token_len = len(request.prompt_token_ids)
|
||||
|
||||
num_external_hit_tokens = self.client.lookup(token_ids)
|
||||
num_external_hit_tokens = self.client.lookup(token_len,
|
||||
request.block_hashes)
|
||||
|
||||
if num_external_hit_tokens == request.num_tokens:
|
||||
num_external_hit_tokens -= 1
|
||||
@@ -216,7 +74,7 @@ class MooncakeStoreConnectorV1Scheduler:
|
||||
need_to_allocate = num_external_hit_tokens - num_computed_tokens
|
||||
|
||||
logger.info(
|
||||
"Reqid: %s, Total tokens %d, mooncake hit tokens: %d, need to load: %d",
|
||||
"Reqid: %s, Total tokens %d, kvpool hit tokens: %d, need to load: %d",
|
||||
request.request_id,
|
||||
request.num_tokens,
|
||||
num_external_hit_tokens,
|
||||
@@ -228,11 +86,11 @@ class MooncakeStoreConnectorV1Scheduler:
|
||||
|
||||
self.load_specs[request.request_id] = LoadSpec(
|
||||
vllm_cached_tokens=num_computed_tokens,
|
||||
mooncake_cached_tokens=num_external_hit_tokens,
|
||||
kvpool_cached_tokens=num_external_hit_tokens,
|
||||
can_load=False,
|
||||
)
|
||||
|
||||
return need_to_allocate, self.load_async
|
||||
return need_to_allocate, self.load_async and not self.use_layerwise
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
@@ -261,10 +119,10 @@ class MooncakeStoreConnectorV1Scheduler:
|
||||
|
||||
assert (
|
||||
num_external_tokens > 0 and num_external_tokens
|
||||
== self.load_specs[request.request_id].mooncake_cached_tokens -
|
||||
== self.load_specs[request.request_id].kvpool_cached_tokens -
|
||||
self.load_specs[request.request_id].vllm_cached_tokens
|
||||
), (f"Mismatch in number of tokens: {num_external_tokens} vs "
|
||||
f"{self.load_specs[request.request_id].mooncake_cached_tokens} - "
|
||||
f"{self.load_specs[request.request_id].kvpool_cached_tokens} - "
|
||||
f"{self.load_specs[request.request_id].vllm_cached_tokens}"
|
||||
f" for request {request.request_id}")
|
||||
|
||||
@@ -289,7 +147,7 @@ class MooncakeStoreConnectorV1Scheduler:
|
||||
self._unfinished_requests.pop(finished_req_id, None)
|
||||
self._unfinished_request_ids.discard(finished_req_id)
|
||||
|
||||
meta = MooncakeConnectorMetadata(self._unfinished_request_ids)
|
||||
meta = AscendConnectorMetadata(self._unfinished_request_ids)
|
||||
|
||||
for request in scheduler_output.scheduled_new_reqs:
|
||||
# Right now, we only load KV for new requests
|
||||
@@ -304,12 +162,15 @@ class MooncakeStoreConnectorV1Scheduler:
|
||||
self._block_size * self._block_size)
|
||||
if self._discard_partial_chunks else len(
|
||||
request.prompt_token_ids))
|
||||
request_tuple = self._unfinished_requests.get(request.req_id)
|
||||
request_real = request_tuple[0] # type: ignore[index]
|
||||
req_meta = ReqMeta.from_request_tracker(
|
||||
request_tracker,
|
||||
self._block_size,
|
||||
load_spec=load_spec,
|
||||
skip_save=force_skip_save,
|
||||
is_last_chunk=len(request_tracker.token_ids)
|
||||
block_hashes=request_real.block_hashes,
|
||||
is_last_chunk=request_tracker.token_len
|
||||
>= last_chunk_tokens_num,
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
)
|
||||
@@ -317,33 +178,14 @@ class MooncakeStoreConnectorV1Scheduler:
|
||||
meta.add_request(req_meta)
|
||||
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
if isinstance(cached_reqs, list) and not force_skip_save:
|
||||
for i, req in enumerate(cached_reqs):
|
||||
request_tracker = self._request_trackers[req.req_id]
|
||||
request_tracker.update(req.new_token_ids, req.new_block_ids)
|
||||
last_chunk_tokens_num = ((len(req.prompt_token_ids) //
|
||||
self._block_size * self._block_size)
|
||||
if self._discard_partial_chunks else
|
||||
len(req.prompt_token_ids))
|
||||
req_meta = ReqMeta.from_request_tracker(
|
||||
request_tracker,
|
||||
self._block_size,
|
||||
load_spec=None,
|
||||
skip_save=force_skip_save,
|
||||
is_last_chunk=len(request_tracker.token_ids)
|
||||
>= last_chunk_tokens_num,
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
)
|
||||
if req_meta is not None:
|
||||
meta.add_request(req_meta)
|
||||
elif not force_skip_save:
|
||||
if not force_skip_save:
|
||||
for i, req_id in enumerate(cached_reqs.req_ids):
|
||||
request_tracker = self._request_trackers[req_id]
|
||||
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
req_tuple = self._unfinished_requests.get(req_id)
|
||||
if req_tuple:
|
||||
request = req_tuple[0]
|
||||
num_current_tokens = len(request_tracker.token_ids)
|
||||
num_current_tokens = request_tracker.token_len
|
||||
new_token_ids = request.all_token_ids[
|
||||
num_current_tokens:num_current_tokens + num_new_tokens]
|
||||
else:
|
||||
@@ -355,8 +197,7 @@ class MooncakeStoreConnectorV1Scheduler:
|
||||
continue
|
||||
request_tracker.update(new_token_ids, new_block_ids)
|
||||
# decode not save
|
||||
if len(request_tracker.token_ids) > len(
|
||||
request.prompt_token_ids):
|
||||
if request_tracker.token_len > len(request.prompt_token_ids):
|
||||
continue
|
||||
|
||||
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
|
||||
@@ -368,7 +209,8 @@ class MooncakeStoreConnectorV1Scheduler:
|
||||
self._block_size,
|
||||
load_spec=None,
|
||||
skip_save=force_skip_save,
|
||||
is_last_chunk=len(request_tracker.token_ids)
|
||||
block_hashes=request.block_hashes,
|
||||
is_last_chunk=request_tracker.token_len
|
||||
>= last_chunk_tokens_num,
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
)
|
||||
@@ -384,15 +226,14 @@ class MooncakeStoreConnectorV1Scheduler:
|
||||
load_spec = self.load_specs.pop(request_id, None)
|
||||
if not load_spec:
|
||||
continue
|
||||
num_tokens_to_compute = load_spec.mooncake_cached_tokens
|
||||
num_tokens_to_compute = load_spec.kvpool_cached_tokens
|
||||
if (num_tokens_to_compute % self._block_size
|
||||
!= 0) and (num_tokens_to_compute
|
||||
== len(request.prompt_token_ids) - 1):
|
||||
num_tokens_to_compute = num_tokens_to_compute + 1
|
||||
request_tracker = RequestTracker(
|
||||
req_id=request_id,
|
||||
token_ids=request.prompt_token_ids[:num_tokens_to_compute].
|
||||
copy(),
|
||||
token_len=num_tokens_to_compute,
|
||||
allocated_block_ids=block_ids,
|
||||
num_saved_tokens=0,
|
||||
)
|
||||
@@ -404,6 +245,7 @@ class MooncakeStoreConnectorV1Scheduler:
|
||||
self._block_size,
|
||||
load_spec=load_spec,
|
||||
skip_save=None,
|
||||
block_hashes=request.block_hashes,
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
)
|
||||
if req_meta is not None:
|
||||
@@ -431,12 +273,12 @@ class MooncakeStoreConnectorV1Scheduler:
|
||||
return delay_free_blocks, None
|
||||
|
||||
|
||||
class MooncakeLookupClient:
|
||||
class LookupKeyClient:
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig"):
|
||||
self.encoder = MsgpackEncoder()
|
||||
self.ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
socket_path = get_zmq_rpc_path_mooncake(vllm_config)
|
||||
socket_path = get_zmq_rpc_path_lookup(vllm_config)
|
||||
self.socket = make_zmq_socket(
|
||||
self.ctx,
|
||||
socket_path,
|
||||
@@ -444,9 +286,12 @@ class MooncakeLookupClient:
|
||||
bind=False,
|
||||
)
|
||||
|
||||
def lookup(self, token_ids: torch.Tensor) -> int:
|
||||
request = self.encoder.encode(token_ids)
|
||||
self.socket.send_multipart(request, copy=False)
|
||||
def lookup(self, token_len: int, block_hashes: list[BlockHash]) -> int:
|
||||
hash_strs = [h.hex() for h in block_hashes]
|
||||
hash_frames = self.encoder.encode(hash_strs)
|
||||
token_len_bytes = token_len.to_bytes(4, byteorder="big")
|
||||
all_frames = [token_len_bytes] + list(hash_frames)
|
||||
self.socket.send_multipart(all_frames, copy=False)
|
||||
resp = self.socket.recv()
|
||||
result = int.from_bytes(resp, "big")
|
||||
return result
|
||||
@@ -455,39 +300,19 @@ class MooncakeLookupClient:
|
||||
self.socket.close(linger=0)
|
||||
|
||||
|
||||
class MooncakeLookupServer:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mooncake_engine: MooncakeEngine,
|
||||
vllm_config: "VllmConfig",
|
||||
use_layerwise: bool,
|
||||
):
|
||||
self.decoder = MsgpackDecoder(torch.Tensor)
|
||||
self.ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
socket_path = get_zmq_rpc_path_mooncake(vllm_config)
|
||||
self.socket = make_zmq_socket(
|
||||
self.ctx,
|
||||
socket_path,
|
||||
zmq.REP, # type: ignore[attr-defined]
|
||||
bind=True,
|
||||
)
|
||||
|
||||
self.mooncake_engine = mooncake_engine
|
||||
self.running = True
|
||||
|
||||
def process_request():
|
||||
while self.running:
|
||||
frames = self.socket.recv_multipart(copy=False)
|
||||
token_ids = self.decoder.decode(frames)
|
||||
result = self.mooncake_engine.lookup_scheduler(
|
||||
token_ids, use_layerwise)
|
||||
response = result.to_bytes(4, "big")
|
||||
self.socket.send(response)
|
||||
|
||||
self.thread = threading.Thread(target=process_request, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def close(self):
|
||||
self.socket.close(linger=0)
|
||||
# TODO: close the thread!
|
||||
def get_zmq_rpc_path_lookup(
|
||||
vllm_config: Optional["VllmConfig"] = None, ) -> str:
|
||||
base_url = envs.VLLM_RPC_BASE_PATH
|
||||
# Default to 0 if not configured
|
||||
rpc_port = 0
|
||||
if vllm_config is not None:
|
||||
extra_config = vllm_config.kv_transfer_config.kv_connector_extra_config
|
||||
if "lookup_rpc_port" in extra_config:
|
||||
rpc_port = extra_config["lookup_rpc_port"]
|
||||
elif "mooncake_rpc_port" in extra_config:
|
||||
rpc_port = extra_config["mooncake_rpc_port"]
|
||||
logger.warning(
|
||||
"It is recommended to use the lookup_rpc_port, as the mooncake_rpc_port will be removed in the future."
|
||||
)
|
||||
logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port)
|
||||
return f"ipc://{base_url}/lookup_rpc_port_{rpc_port}"
|
||||
@@ -1,25 +1,33 @@
|
||||
# Standard
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
from typing import Generator, List, Optional, Union
|
||||
from typing import Dict, Generator, Optional, Type
|
||||
|
||||
# Third Party
|
||||
import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import logger
|
||||
from vllm.utils.torch_utils import get_kv_cache_torch_dtype
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
|
||||
from vllm_ascend.distributed.mooncake.config_data import (
|
||||
ChunkedTokenDatabase, LasyerMultiBlockReqMeta, MooncakeConnectorMetadata,
|
||||
MooncakeEngineMetadata)
|
||||
from vllm_ascend.distributed.mooncake.kv_transfer import (
|
||||
from vllm_ascend.distributed.kvpool.backend.backend import Backend
|
||||
from vllm_ascend.distributed.kvpool.backend.memcache_backend import \
|
||||
MemcacheBackend
|
||||
from vllm_ascend.distributed.kvpool.backend.mooncake_backend import \
|
||||
MooncakeBackend
|
||||
from vllm_ascend.distributed.kvpool.config_data import (
|
||||
AscendConnectorMetadata, ChunkedTokenDatabase, KeyMetadata,
|
||||
LasyerMultiBlockReqMeta)
|
||||
from vllm_ascend.distributed.kvpool.kv_transfer import (
|
||||
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
|
||||
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
|
||||
from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore
|
||||
|
||||
backend_map: Dict[str, Type[Backend]] = {
|
||||
"mooncake": MooncakeBackend,
|
||||
"memcache": MemcacheBackend,
|
||||
}
|
||||
|
||||
|
||||
class MooncakeEngine:
|
||||
class KVPoolWorker:
|
||||
#The main class for the cache engine.
|
||||
|
||||
def __init__(
|
||||
@@ -29,6 +37,7 @@ class MooncakeEngine:
|
||||
):
|
||||
model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.dp_rank = parallel_config.data_parallel_rank
|
||||
self.use_mla = False
|
||||
if (hasattr(model_config, "use_mla")
|
||||
and isinstance(model_config.use_mla, bool)
|
||||
@@ -40,37 +49,37 @@ class MooncakeEngine:
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"load_async", False)
|
||||
self.register_buffer = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"register_buffer", False)
|
||||
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"backend", "mooncake")
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.current_layer = 0
|
||||
# self.use_mla = first_kv_cache_tuple[0].size(
|
||||
# -1) != first_kv_cache_tuple[1].size(-1)
|
||||
self.num_layers = model_config.get_num_layers(parallel_config)
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
num_kv_head = model_config.get_num_kv_heads(parallel_config)
|
||||
head_size = model_config.get_head_size()
|
||||
kv_dtype = get_kv_cache_torch_dtype(
|
||||
vllm_config.cache_config.cache_dtype, model_config.dtype)
|
||||
self.hidden_dim_size = num_kv_head * head_size
|
||||
|
||||
if self.use_mla:
|
||||
kv_shape = (self.num_layers, 1, self.block_size, 1, head_size)
|
||||
self.num_kv_head = 1
|
||||
else:
|
||||
kv_shape = (self.num_layers, 2, self.block_size, num_kv_head,
|
||||
head_size)
|
||||
self.metadata = MooncakeEngineMetadata(
|
||||
self.num_kv_head = model_config.get_total_num_kv_heads()
|
||||
|
||||
if self.num_kv_head < self.tp_size:
|
||||
self.put_step = self.tp_size // self.num_kv_head
|
||||
self.head_or_tp_rank = self.tp_rank // self.put_step
|
||||
else:
|
||||
self.head_or_tp_rank = self.tp_rank
|
||||
self.put_step = 1
|
||||
|
||||
self.metadata = KeyMetadata(
|
||||
model_config.model,
|
||||
parallel_config.world_size,
|
||||
parallel_config.rank,
|
||||
kv_dtype,
|
||||
kv_shape,
|
||||
self.block_size,
|
||||
self.use_mla,
|
||||
self.head_or_tp_rank,
|
||||
)
|
||||
|
||||
self.token_database = ChunkedTokenDatabase(self.metadata)
|
||||
self.token_database = ChunkedTokenDatabase(self.metadata,
|
||||
self.block_size,
|
||||
self.use_mla)
|
||||
|
||||
self.m_store = Mooncakestore(parallel_config)
|
||||
real_backend = backend_map.get(self.backend.lower())
|
||||
self.m_store = real_backend( # type: ignore[misc]
|
||||
parallel_config)
|
||||
|
||||
self.kv_send_thread: Optional[KVTransferThread] = None
|
||||
self.kv_recv_thread: Optional[KVTransferThread] = None
|
||||
@@ -108,94 +117,83 @@ class MooncakeEngine:
|
||||
|
||||
self.kv_caches = kv_caches
|
||||
self.kv_caches_base_addr = []
|
||||
ptrs = []
|
||||
lengths = []
|
||||
for cache_or_caches in kv_caches.values():
|
||||
# Normalize to always be a list of caches
|
||||
if self.use_mla:
|
||||
for i, cache in enumerate(cache_or_caches, 0):
|
||||
base_addr = cache.data_ptr()
|
||||
self.kv_caches_base_addr.append(base_addr)
|
||||
if self.register_buffer:
|
||||
region_len = self.num_blocks * self.block_len[i % 2]
|
||||
self._register(base_addr, region_len)
|
||||
region_len = self.num_blocks * self.block_len[i % 2]
|
||||
ptrs.append(base_addr)
|
||||
lengths.append(region_len)
|
||||
else:
|
||||
cache_list = [cache_or_caches
|
||||
] if self.use_mla else cache_or_caches
|
||||
for cache in cache_list:
|
||||
base_addr = cache.data_ptr()
|
||||
self.kv_caches_base_addr.append(base_addr)
|
||||
if self.register_buffer:
|
||||
region_len = self.num_blocks * self.block_len[0]
|
||||
self._register(base_addr, region_len)
|
||||
region_len = self.num_blocks * self.block_len[0]
|
||||
ptrs.append(base_addr)
|
||||
lengths.append(region_len)
|
||||
self.m_store.register_buffer(ptrs, lengths)
|
||||
self.token_database.set_kv_caches_base_addr(self.kv_caches_base_addr)
|
||||
self.token_database.set_block_len(self.block_len)
|
||||
|
||||
if self.use_layerwise:
|
||||
self.get_event = threading.Event()
|
||||
if self.kv_role in ['kv_producer', 'kv_both']:
|
||||
ready_event_sending = threading.Event()
|
||||
self.kv_send_thread = KVCacheStoreLayerSendingThread(
|
||||
self.tp_rank, self.tp_size, self.m_store,
|
||||
self.kv_caches_base_addr, self.token_database,
|
||||
self.block_len, self.block_size, ready_event_sending,
|
||||
self.num_layers)
|
||||
self.m_store, self.token_database, self.tp_rank,
|
||||
self.put_step, ready_event_sending, self.num_layers)
|
||||
self.kv_send_thread.start()
|
||||
ready_event = threading.Event()
|
||||
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
|
||||
self.tp_rank, self.tp_size, self.m_store,
|
||||
self.kv_caches_base_addr, self.token_database, self.block_len,
|
||||
self.block_size, ready_event, self.get_event)
|
||||
self.m_store, self.token_database, self.tp_rank, ready_event,
|
||||
self.get_event)
|
||||
self.kv_recv_thread.start()
|
||||
ready_event.wait()
|
||||
else:
|
||||
if self.kv_role in ['kv_producer', 'kv_both']:
|
||||
ready_event_sending = threading.Event()
|
||||
self.kv_send_thread = KVCacheStoreSendingThread(
|
||||
self.tp_rank, self.tp_size, self.m_store,
|
||||
self.kv_caches_base_addr, self.token_database,
|
||||
self.block_len, self.block_size, ready_event_sending)
|
||||
self.m_store, self.token_database, self.tp_rank,
|
||||
self.put_step, ready_event_sending)
|
||||
self.kv_send_thread.start()
|
||||
if self.load_async:
|
||||
ready_event = threading.Event()
|
||||
self.kv_recv_thread = KVCacheStoreRecvingThread(
|
||||
self.tp_rank, self.tp_size, self.m_store,
|
||||
self.kv_caches_base_addr, self.token_database,
|
||||
self.block_len, self.block_size, ready_event)
|
||||
self.m_store, self.token_database, self.tp_rank,
|
||||
ready_event)
|
||||
self.kv_recv_thread.start()
|
||||
ready_event.wait()
|
||||
|
||||
def _register(self, ptr, length):
|
||||
logger.debug(
|
||||
"Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
|
||||
"block_lens=%s", ptr, length, self.num_blocks, self.block_len)
|
||||
try:
|
||||
self.m_store.register_buffer(ptr, length)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Mooncake memory registration failed. Error is: {e}")
|
||||
|
||||
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
|
||||
def start_load_kv(self, metadata: AscendConnectorMetadata):
|
||||
self.current_layer = 0
|
||||
self.layerwise_retrievers = []
|
||||
for request in metadata.requests:
|
||||
load_spec = request.load_spec
|
||||
if load_spec is None or not load_spec.can_load: #load =0
|
||||
continue
|
||||
tokens = request.token_ids
|
||||
token_len = request.token_len_chunk
|
||||
req_id = request.req_id
|
||||
if (load_spec.mooncake_cached_tokens % self.block_size
|
||||
!= 0) and (load_spec.mooncake_cached_tokens
|
||||
== tokens.shape[0] - 1):
|
||||
tokens = tokens[:request.load_spec.mooncake_cached_tokens + 1]
|
||||
if (load_spec.kvpool_cached_tokens % self.block_size
|
||||
!= 0) and (load_spec.kvpool_cached_tokens
|
||||
== token_len - 1):
|
||||
token_len = request.load_spec.kvpool_cached_tokens + 1
|
||||
else:
|
||||
tokens = tokens[:request.load_spec.mooncake_cached_tokens]
|
||||
masked_token_count = (request.load_spec.vllm_cached_tokens //
|
||||
self.block_size * self.block_size)
|
||||
token_mask = torch.ones_like(tokens, dtype=torch.bool)
|
||||
token_mask[:masked_token_count] = False
|
||||
token_len = request.load_spec.kvpool_cached_tokens
|
||||
mask_num = (request.load_spec.vllm_cached_tokens //
|
||||
self.block_size * self.block_size)
|
||||
if self.use_layerwise:
|
||||
layerwise_retriever = self.retrieve_layer(
|
||||
req_id,
|
||||
tokens,
|
||||
token_len,
|
||||
request.block_ids,
|
||||
token_mask,
|
||||
request.block_hashes,
|
||||
mask_num,
|
||||
)
|
||||
next(layerwise_retriever) # first layer load
|
||||
self.layerwise_retrievers.append(layerwise_retriever)
|
||||
@@ -203,102 +201,84 @@ class MooncakeEngine:
|
||||
if self.load_async:
|
||||
self.kv_recv_thread.add_request( # type: ignore[union-attr]
|
||||
req_id,
|
||||
tokens,
|
||||
token_len,
|
||||
request.block_ids,
|
||||
token_mask,
|
||||
request.block_hashes,
|
||||
mask_num,
|
||||
)
|
||||
else:
|
||||
if self.m_store.config.use_ascend_direct:
|
||||
addr_list = []
|
||||
size_list = []
|
||||
key_list = []
|
||||
blockIds = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, token_mask):
|
||||
addr, size, block_id = self.prepare_value(
|
||||
start, end, request.block_ids)
|
||||
key_list.append(key.to_string())
|
||||
addr_list.append(addr)
|
||||
size_list.append(size)
|
||||
blockIds.append(block_id)
|
||||
self.m_store.get_batch(key_list, addr_list, size_list,
|
||||
blockIds)
|
||||
else:
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, token_mask):
|
||||
addr, size, _ = self.prepare_value(
|
||||
start, end, request.block_ids)
|
||||
self.m_store.get(key, addr, size)
|
||||
|
||||
def prepare_value(self, start: int, end: int, block_ids: list[int]):
|
||||
addr_list = []
|
||||
size_list = []
|
||||
block_id = block_ids[start // self.block_size]
|
||||
for index, base_addr in enumerate(self.kv_caches_base_addr):
|
||||
block_len = (self.block_len[index % 2]
|
||||
if self.use_mla else self.block_len[0])
|
||||
|
||||
addr = base_addr + block_id * block_len
|
||||
length = int(block_len / self.block_size * (end - start))
|
||||
addr_list.append(addr)
|
||||
size_list.append(length)
|
||||
return addr_list, size_list, block_id
|
||||
addr_list = []
|
||||
size_list = []
|
||||
key_list = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
token_len, request.block_hashes, mask_num):
|
||||
addr, size, _ = self.token_database.prepare_value(
|
||||
start, end, request.block_ids)
|
||||
key_list.append(key.to_string())
|
||||
addr_list.append(addr)
|
||||
size_list.append(size)
|
||||
key_list_c = key_list[self.tp_rank % len(
|
||||
key_list):] + key_list[:self.tp_rank % len(key_list)]
|
||||
addr_list_c = addr_list[self.tp_rank %
|
||||
len(addr_list
|
||||
):] + addr_list[:self.tp_rank %
|
||||
len(addr_list)]
|
||||
size_list_c = size_list[self.tp_rank %
|
||||
len(size_list
|
||||
):] + size_list[:self.tp_rank %
|
||||
len(size_list)]
|
||||
self.m_store.get(key_list_c, addr_list_c, size_list_c)
|
||||
|
||||
def wait_for_layer_load(self) -> None:
|
||||
"""MooncakeConnector does not do layerwise saving."""
|
||||
for layerwise_retriever in self.layerwise_retrievers:
|
||||
ret_token_mask = next(layerwise_retriever)
|
||||
if self.current_layer == self.num_layers - 1:
|
||||
assert ret_token_mask is not None
|
||||
num_retrieved_tokens = ret_token_mask.sum().item()
|
||||
logger.info(f"Retrieved {num_retrieved_tokens} tokens")
|
||||
logger.debug(f"Retrieved {num_retrieved_tokens} tokens")
|
||||
|
||||
def save_kv_layer(self,
|
||||
connector_metadata: MooncakeConnectorMetadata) -> None:
|
||||
"""MooncakeConnector does not save explicitly."""
|
||||
connector_metadata: AscendConnectorMetadata) -> None:
|
||||
if self.current_layer == 0:
|
||||
self.layerwise_storers = []
|
||||
for request in connector_metadata.requests:
|
||||
save_spec = request.save_spec
|
||||
if save_spec is None or not save_spec.can_save:
|
||||
can_save = request.can_save
|
||||
if can_save is None or not can_save:
|
||||
continue
|
||||
|
||||
token_ids = request.token_ids
|
||||
token_len = request.token_len_chunk
|
||||
req_id = request.req_id
|
||||
assert isinstance(token_ids, torch.Tensor)
|
||||
assert token_ids.is_cpu
|
||||
|
||||
# TODO: whether need to remov saveThread
|
||||
# no lookup, skipmask
|
||||
skip_leading_tokens = max(
|
||||
self.lookup(token_ids, self.use_layerwise),
|
||||
save_spec.skip_leading_tokens,
|
||||
)
|
||||
if skip_leading_tokens == len(token_ids):
|
||||
skip_leading_tokens = self.lookup(token_len,
|
||||
request.block_hashes,
|
||||
self.use_layerwise)
|
||||
if skip_leading_tokens == token_len:
|
||||
if request.is_last_chunk:
|
||||
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
|
||||
req_id)
|
||||
continue # skip this request
|
||||
|
||||
skip_leading_tokens = (skip_leading_tokens // self.block_size *
|
||||
self.block_size)
|
||||
mask_num = (skip_leading_tokens // self.block_size *
|
||||
self.block_size)
|
||||
|
||||
store_mask = torch.ones_like(token_ids, dtype=torch.bool)
|
||||
store_mask[:skip_leading_tokens] = False
|
||||
logger.info(
|
||||
"Storing KV cache for %d out of %d tokens "
|
||||
"(skip_leading_tokens=%d) for request %s",
|
||||
len(token_ids) - skip_leading_tokens,
|
||||
len(token_ids),
|
||||
token_len - skip_leading_tokens,
|
||||
token_len,
|
||||
skip_leading_tokens,
|
||||
request.req_id,
|
||||
)
|
||||
|
||||
layerwise_storer = self.store_layer(
|
||||
req_id,
|
||||
token_ids,
|
||||
mask=store_mask,
|
||||
token_len,
|
||||
block_hashes=request.block_hashes,
|
||||
mask_num=mask_num,
|
||||
block_ids=request.block_ids,
|
||||
is_last_chunk=request.is_last_chunk,
|
||||
)
|
||||
self.layerwise_storers.append(layerwise_storer)
|
||||
for layerwise_storer in self.layerwise_storers:
|
||||
@@ -306,59 +286,53 @@ class MooncakeEngine:
|
||||
next(layerwise_storer)
|
||||
except Exception:
|
||||
raise
|
||||
self.current_layer = self.current_layer + 1
|
||||
self.current_layer = self.current_layer + 1
|
||||
|
||||
def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata):
|
||||
"""MooncakeConnector does not save explicitly."""
|
||||
def wait_for_save(self, connector_metadata: AscendConnectorMetadata):
|
||||
for request in connector_metadata.requests:
|
||||
save_spec = request.save_spec
|
||||
if save_spec is None or not save_spec.can_save:
|
||||
can_save = request.can_save
|
||||
if can_save is None or not can_save:
|
||||
continue
|
||||
|
||||
token_ids = request.token_ids
|
||||
token_len = request.token_len_chunk
|
||||
req_id = request.req_id
|
||||
assert isinstance(token_ids, torch.Tensor)
|
||||
assert token_ids.is_cpu
|
||||
|
||||
skip_leading_tokens = max(
|
||||
self.lookup(token_ids, self.use_layerwise),
|
||||
save_spec.skip_leading_tokens,
|
||||
)
|
||||
if skip_leading_tokens == len(token_ids):
|
||||
skip_leading_tokens = self.lookup(token_len, request.block_hashes,
|
||||
self.use_layerwise)
|
||||
if skip_leading_tokens == token_len:
|
||||
if request.is_last_chunk:
|
||||
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
|
||||
req_id)
|
||||
continue # skip this request
|
||||
|
||||
skip_leading_tokens = (skip_leading_tokens // self.block_size *
|
||||
self.block_size)
|
||||
|
||||
store_mask = torch.ones_like(token_ids, dtype=torch.bool)
|
||||
store_mask[:skip_leading_tokens] = False
|
||||
mask_num = (skip_leading_tokens // self.block_size *
|
||||
self.block_size)
|
||||
|
||||
logger.info(
|
||||
"Storing KV cache for %d out of %d tokens "
|
||||
"(skip_leading_tokens=%d) for request %s",
|
||||
len(token_ids) - skip_leading_tokens,
|
||||
len(token_ids),
|
||||
token_len - skip_leading_tokens,
|
||||
token_len,
|
||||
skip_leading_tokens,
|
||||
request.req_id,
|
||||
)
|
||||
|
||||
self.kv_send_thread.add_request( # type: ignore[union-attr]
|
||||
req_id,
|
||||
token_ids,
|
||||
token_len,
|
||||
request.block_ids,
|
||||
store_mask,
|
||||
request.block_hashes,
|
||||
mask_num,
|
||||
request.is_last_chunk,
|
||||
)
|
||||
|
||||
def retrieve_layer(
|
||||
self,
|
||||
req_id: str,
|
||||
tokens: torch.Tensor,
|
||||
token_len: int,
|
||||
block_ids: list[int],
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
block_hashes: list[BlockHash],
|
||||
mask_num: int = 0,
|
||||
) -> Generator[Optional[torch.Tensor], None, None]:
|
||||
"""
|
||||
Retrieve the KV cache in a layerwise manner.
|
||||
@@ -376,20 +350,16 @@ class MooncakeEngine:
|
||||
be the boolean mask indicating which tokens are retrieved and will
|
||||
only be returned in the last iteration.
|
||||
"""
|
||||
num_required_tokens = token_len - mask_num
|
||||
|
||||
if mask is not None:
|
||||
num_required_tokens = torch.sum(mask).item()
|
||||
else:
|
||||
num_required_tokens = len(tokens)
|
||||
|
||||
ret_mask = torch.zeros_like(tokens, dtype=torch.bool, device="cpu")
|
||||
ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu")
|
||||
|
||||
starts = []
|
||||
ends = []
|
||||
keys = []
|
||||
first_flag = True
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, mask):
|
||||
token_len, block_hashes, mask_num):
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
starts.append(start)
|
||||
ends.append(end)
|
||||
@@ -421,16 +391,18 @@ class MooncakeEngine:
|
||||
retrieved_tokens = torch.sum(ret_mask)
|
||||
logger.debug(f"Retrieved {retrieved_tokens} "
|
||||
f"out of {num_required_tokens} "
|
||||
f"out of total {len(tokens)} tokens")
|
||||
f"out of total {token_len} tokens")
|
||||
|
||||
yield ret_mask
|
||||
|
||||
def store_layer(
|
||||
self,
|
||||
req_id: str,
|
||||
tokens: torch.Tensor,
|
||||
token_len: int,
|
||||
block_ids: list[int],
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
block_hashes: list[BlockHash],
|
||||
is_last_chunk: bool,
|
||||
mask_num: int = 0,
|
||||
) -> Generator[None, None, None]:
|
||||
"""
|
||||
Store the KV cache in a layerwise manner.
|
||||
@@ -452,17 +424,13 @@ class MooncakeEngine:
|
||||
storage backends. In the last iteration, it puts the memory objects
|
||||
of the last layer to the storage backends.
|
||||
"""
|
||||
|
||||
if mask is not None:
|
||||
num_stored_tokens = torch.sum(mask).item()
|
||||
else:
|
||||
num_stored_tokens = len(tokens)
|
||||
num_stored_tokens = token_len - mask_num
|
||||
|
||||
starts = []
|
||||
ends = []
|
||||
keys = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens, mask):
|
||||
token_len, block_hashes, mask_num):
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
starts.append(start)
|
||||
ends.append(end)
|
||||
@@ -473,7 +441,7 @@ class MooncakeEngine:
|
||||
for layer_id, keys_multi_chunk in enumerate(keys):
|
||||
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
|
||||
starts, ends, block_ids,
|
||||
layer_id)
|
||||
layer_id, is_last_chunk)
|
||||
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
|
||||
req_meta) # type: ignore[union-attr, call-arg, arg-type]
|
||||
yield
|
||||
@@ -481,7 +449,7 @@ class MooncakeEngine:
|
||||
for layer_id in range(self.num_layers):
|
||||
yield
|
||||
logger.debug(
|
||||
f"Stored {num_stored_tokens} out of total {len(tokens)} tokens")
|
||||
f"Stored {num_stored_tokens} out of total {token_len} tokens")
|
||||
|
||||
def get_finished(self) -> tuple[set[str], set[str]]:
|
||||
done_sending = (
|
||||
@@ -500,13 +468,10 @@ class MooncakeEngine:
|
||||
self.tp_rank)
|
||||
return done_sending, done_recving
|
||||
|
||||
def wait_layer_transfer_finish(self):
|
||||
time.sleep(10)
|
||||
pass
|
||||
|
||||
def lookup(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
token_len: int,
|
||||
block_hashes: list[BlockHash],
|
||||
use_layerwise: bool,
|
||||
) -> int:
|
||||
"""
|
||||
@@ -517,34 +482,24 @@ class MooncakeEngine:
|
||||
end = 0
|
||||
keys = []
|
||||
try:
|
||||
if use_layerwise:
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens):
|
||||
starts = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
token_len, block_hashes):
|
||||
if use_layerwise:
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
for item in keys_multi_layer:
|
||||
keys.append(item.to_string())
|
||||
# batch is_exists
|
||||
ress = self.m_store.batch_exists(keys)
|
||||
res = 1
|
||||
for value in ress:
|
||||
if value != 1:
|
||||
res = 0
|
||||
break
|
||||
if res == 1:
|
||||
continue
|
||||
else:
|
||||
return start
|
||||
else:
|
||||
starts = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens):
|
||||
else:
|
||||
keys.append(key.to_string())
|
||||
starts.append(start)
|
||||
res = self.m_store.batch_exists(
|
||||
keys) # type: ignore[assignment]
|
||||
for index, value in enumerate(res): # type: ignore[arg-type]
|
||||
if value != 1:
|
||||
return starts[index]
|
||||
starts.append(start)
|
||||
|
||||
res = self.m_store.exists(keys) # type: ignore[assignment]
|
||||
|
||||
if use_layerwise:
|
||||
res = self.check_all_layers_exists(res, self.num_layers)
|
||||
for index, value in enumerate(res): # type: ignore[arg-type]
|
||||
if value != 1:
|
||||
return starts[index]
|
||||
# all tokens where found, return the maximal end
|
||||
except Exception as e:
|
||||
logger.error(f"Remote connection failed in contains: {e}")
|
||||
@@ -553,7 +508,8 @@ class MooncakeEngine:
|
||||
|
||||
def lookup_scheduler(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
token_len: int,
|
||||
block_hashes: list[BlockHash],
|
||||
use_layerwise: bool,
|
||||
) -> int:
|
||||
"""
|
||||
@@ -564,59 +520,59 @@ class MooncakeEngine:
|
||||
end = 0
|
||||
keys = []
|
||||
try:
|
||||
if use_layerwise:
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens):
|
||||
starts = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
token_len, block_hashes):
|
||||
if use_layerwise:
|
||||
keys_multi_layer = key.split_layers(self.num_layers)
|
||||
for item in keys_multi_layer:
|
||||
keys.append(item.to_string())
|
||||
# batch is_exists
|
||||
ress = self.m_store.batch_exists(keys)
|
||||
res = 1
|
||||
for value in ress:
|
||||
if value != 1:
|
||||
res = 0
|
||||
break
|
||||
if res == 1:
|
||||
continue
|
||||
else:
|
||||
return start
|
||||
else:
|
||||
starts = []
|
||||
for start, end, key in self.token_database.process_tokens(
|
||||
tokens):
|
||||
else:
|
||||
keys.append(key.to_string())
|
||||
starts.append(start)
|
||||
multi_tp_keys = keys[:]
|
||||
for i in range(1, self.tp_size):
|
||||
for item in keys:
|
||||
new_str = item.replace( # type: ignore[attr-defined]
|
||||
"@0", f"@{i}", 1)
|
||||
multi_tp_keys.append(new_str)
|
||||
res = self.m_store.batch_exists(
|
||||
multi_tp_keys) # type: ignore[assignment]
|
||||
num_block = len(keys)
|
||||
multi_tp_values = [
|
||||
res[i * num_block:(i + 1) *
|
||||
num_block] # type: ignore[index]
|
||||
for i in range(self.tp_size)
|
||||
]
|
||||
index = self.find_min_first_non_one_index(multi_tp_values)
|
||||
if index != -1:
|
||||
return starts[index]
|
||||
starts.append(start)
|
||||
|
||||
multi_tp_keys = keys[:]
|
||||
for i in range(1, min(self.tp_size, self.num_kv_head)):
|
||||
for item in keys:
|
||||
new_str = item.replace( # type: ignore[attr-defined]
|
||||
"@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1)
|
||||
multi_tp_keys.append(new_str)
|
||||
|
||||
res = self.m_store.exists(
|
||||
multi_tp_keys) # type: ignore[assignment]
|
||||
num_block = len(keys)
|
||||
if use_layerwise:
|
||||
res = self.check_all_layers_exists(res, self.num_layers)
|
||||
num_block = len(keys) // self.num_layers
|
||||
multi_tp_values = [
|
||||
res[i * num_block:(i + 1) * num_block] # type: ignore[index]
|
||||
for i in range(min(self.tp_size, self.num_kv_head))
|
||||
]
|
||||
index = self.find_min_first_non_one_index(multi_tp_values)
|
||||
if index != -1:
|
||||
return starts[index]
|
||||
# all tokens where found, return the maximal end
|
||||
except Exception as e:
|
||||
logger.error(f"Remote connection failed in contains: {e}")
|
||||
return start
|
||||
return end
|
||||
|
||||
def check_all_layers_exists(self, res: list[int],
                            num_layers: int) -> list[int]:
    """Collapse per-layer existence flags into per-chunk flags.

    ``res`` holds ``num_layers`` consecutive flags per chunk; a chunk is
    marked 1 only when every one of its layer flags equals 1. Any
    trailing partial group (len(res) not a multiple of num_layers) is
    ignored, matching integer-division chunk counting.
    """
    return [
        1 if all(flag == 1 for flag in res[pos:pos + num_layers]) else 0
        for pos in range(0, len(res) // num_layers * num_layers, num_layers)
    ]
|
||||
|
||||
def find_min_first_non_one_index(self, arr):
    """Return the smallest column index holding a non-1 value in any row.

    Scans every row of ``arr`` for its first value differing from 1 and
    returns the minimum such position across all rows, or -1 when every
    value in every row equals 1 (or ``arr`` is empty).
    """
    best = -1
    for row in arr:
        for idx, val in enumerate(row):
            if val != 1:
                # Only the first deviation per row can be the row minimum.
                if best == -1 or idx < best:
                    best = idx
                break
    return best
|
||||
|
||||
def close(self) -> None:
    """Close the cache engine and free all the resources."""
    # Teardown is fully delegated to the underlying KV store backend.
    self.m_store.close()
|
||||
@@ -28,6 +28,7 @@ from vllm.forward_context import ForwardContext
|
||||
from vllm.utils import logger
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
@@ -100,7 +101,10 @@ class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata):
|
||||
|
||||
class LLMDataDistCMgrConnector(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional[KVCacheConfig] = None):
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
self.engine_id = vllm_config.kv_transfer_config.engine_id
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
|
||||
@@ -1,534 +0,0 @@
|
||||
import array
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
|
||||
KVConnectorMetadata
|
||||
from vllm.utils import logger
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.core.sched.output import NewRequestData
|
||||
|
||||
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
|
||||
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
|
||||
|
||||
|
||||
@dataclass
class MooncakeEngineMetadata:
    """Static description of the KV cache layout for one engine worker.

    Previously the class docstring read "name of the LLM model" (it
    described the first field, not the class), and the remaining fields
    were annotated with bare string statements, which are no-ops rather
    than documentation. The field docs are now plain comments.
    """

    # Name of the LLM model.
    model_name: str
    # World size when running under a distributed setting.
    world_size: int
    # Worker id when running under a distributed setting.
    worker_id: int
    # The dtype of the kv tensors.
    kv_dtype: torch.dtype
    # The shape of the kv tensors:
    # (num_layer, 2, block_size, num_kv_head, head_size).
    kv_shape: tuple[int, int, int, int, int]
    # Number of tokens covered by one KV cache block.
    block_size: int = 128
    # Whether the model uses an MLA-style KV layout.
    use_mla: bool = False
|
||||
|
||||
|
||||
@dataclass(order=True)
class MooncakeEngineKey:
    """Identifies one chunk of KV cache within the remote store."""

    model_name: str
    world_size: int
    worker_id: int
    chunk_hash: str

    def __hash__(self):
        # Hash the full identity tuple so keys work as dict/set members.
        identity = (self.model_name, self.world_size, self.worker_id,
                    self.chunk_hash)
        return hash(identity)

    def to_string(self):
        """Serialize the key as '@'-joined fields."""
        fields = (self.model_name, str(self.world_size),
                  str(self.worker_id), self.chunk_hash)
        return "@".join(fields)

    def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]:
        """Split the key into multiple keys, one per layer id."""
        return [
            LayerMooncakeEngineKey(self.model_name, self.world_size,
                                   self.worker_id, self.chunk_hash, layer_id)
            for layer_id in range(num_layers)
        ]

    def to_dict(self):
        # Note(Kuntai): this is used for serializing CacheEngineKey via msgpack.
        return {
            "__type__": "CacheEngineKey",
            "model_name": self.model_name,
            "world_size": self.world_size,
            "worker_id": self.worker_id,
            "chunk_hash": self.chunk_hash,
        }

    @staticmethod
    def from_dict(d):
        """Rebuild a key from its ``to_dict`` representation."""
        return MooncakeEngineKey(model_name=d["model_name"],
                                 world_size=d["world_size"],
                                 worker_id=d["worker_id"],
                                 chunk_hash=d["chunk_hash"])
|
||||
|
||||
|
||||
@dataclass(order=True)
class LayerMooncakeEngineKey(MooncakeEngineKey):
    """A key for the layer cache engine: the base key plus a layer id."""

    layer_id: int

    def __hash__(self):
        # Include layer_id so per-layer keys of one chunk stay distinct.
        identity = (self.model_name, self.world_size, self.worker_id,
                    self.chunk_hash, self.layer_id)
        return hash(identity)

    def to_string(self):
        """Serialize as the base-key string with the layer id appended."""
        base = (f"{self.model_name}@{self.world_size}"
                f"@{self.worker_id}@{self.chunk_hash}")
        return f"{base}@{self.layer_id}"
|
||||
|
||||
|
||||
class ChunkedTokenDatabase():
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metadata: MooncakeEngineMetadata,
|
||||
):
|
||||
self.metadata = metadata
|
||||
|
||||
def _make_key_by_hash(self,
|
||||
chunk_hash: str,
|
||||
layer_id: Optional[int] = None):
|
||||
assert self.metadata is not None
|
||||
return MooncakeEngineKey(
|
||||
self.metadata.model_name,
|
||||
self.metadata.world_size,
|
||||
self.metadata.worker_id,
|
||||
chunk_hash,
|
||||
)
|
||||
|
||||
def _hash(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
prefix_hash: str,
|
||||
) -> str:
|
||||
# TODO: change it to a more efficient hash function
|
||||
if isinstance(tokens, torch.Tensor):
|
||||
tokens_bytes = tokens.cpu().to(torch.uint32).numpy().tobytes()
|
||||
elif isinstance(tokens, list):
|
||||
tokens_bytes = array.array("I", tokens).tobytes()
|
||||
return hashlib.sha256(prefix_hash.encode("ascii") +
|
||||
tokens_bytes).hexdigest()
|
||||
|
||||
def _chunk_tokens(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
) -> Iterable[Union[torch.Tensor, List[int]]]:
|
||||
"""
|
||||
Chunk the tokens into chunks of size self.metadata.block_size.
|
||||
|
||||
:param tokens: the input tokens, with shape [seq_len]
|
||||
device: the target device after chunking
|
||||
|
||||
:return: a generator of chunks of tokens, each with
|
||||
shape [metadata.block_size]
|
||||
"""
|
||||
for i in range(0, len(tokens), self.metadata.block_size):
|
||||
yield tokens[i:i + self.metadata.block_size]
|
||||
|
||||
def _prefix_hash(
|
||||
self,
|
||||
token_chunks: Iterable[Union[torch.Tensor, List[int]]],
|
||||
) -> Iterable[str]:
|
||||
prefix_hash = ''
|
||||
for token_chunk in token_chunks:
|
||||
prefix_hash = self._hash(token_chunk, prefix_hash)
|
||||
yield prefix_hash
|
||||
|
||||
def process_tokens(
|
||||
self,
|
||||
tokens: Union[torch.Tensor, List[int]],
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> Iterable[Tuple[int, int, MooncakeEngineKey]]:
|
||||
"""Process the tokens and return the corresponding cache engine keys.
|
||||
|
||||
:param Union[torch.Tensor, List[int]] tokens: The tokens to process.
|
||||
|
||||
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
|
||||
have the same length as tokens. And the mask should ALWAYS be like
|
||||
FFFFFTTTTTTT, where True means the tokens needs to be matched,
|
||||
and the Falses will ALWAYS be at the PREFIX of the tensor.
|
||||
|
||||
:param bool make_key: Whether to make the cache engine key or not.
|
||||
If False, the hash value will be returned instead.
|
||||
|
||||
:returns: A iterable of tuples with three elements. The first element
|
||||
is the start index of the tokens for the key. The second element
|
||||
is the end index of the tokens for the key. The third element is
|
||||
the cache engine key (or hash) for the tokens.
|
||||
|
||||
:raises: ValueError if the number of Falses in the mask is not a
|
||||
multiple of the chunk size.
|
||||
"""
|
||||
if mask is not None:
|
||||
num_falses = mask.numel() - mask.long().sum().item()
|
||||
else:
|
||||
num_falses = 0
|
||||
|
||||
if num_falses % self.metadata.block_size != 0:
|
||||
raise ValueError(
|
||||
"The number of Falses in the mask is not a multiple of the chunk size."
|
||||
)
|
||||
total_len = len(tokens)
|
||||
|
||||
token_chunks = self._chunk_tokens(tokens)
|
||||
prefix_hashes = self._prefix_hash(token_chunks)
|
||||
|
||||
start_idx = 0
|
||||
for chunk_id, hash_val in enumerate(prefix_hashes):
|
||||
start_idx = chunk_id * self.metadata.block_size
|
||||
end_idx = min(start_idx + self.metadata.block_size, total_len)
|
||||
if start_idx < num_falses:
|
||||
continue
|
||||
else:
|
||||
yield start_idx, end_idx, self._make_key_by_hash(hash_val)
|
||||
|
||||
|
||||
@dataclass
class LoadSpec:
    """Describes how many tokens a request can load, and from where."""

    # Number of tokens cached in vLLM.
    vllm_cached_tokens: int
    # Number of tokens that are cached in mooncake.
    mooncake_cached_tokens: int
    # Whether the scheduler allow us to load the tokens.
    can_load: bool
|
||||
|
||||
|
||||
@dataclass
class SaveSpec:
    """Describes how much of a request still needs saving."""

    # Skip already saved tokens.
    skip_leading_tokens: int
    # Whether the scheduler allow us to save the tokens.
    can_save: bool
|
||||
|
||||
|
||||
@dataclass
class RequestTracker:
    """Tracks the scheduling progress of one request across steps."""

    # Request id
    req_id: str

    # The token ids that have been scheduled so far
    token_ids: list[int]

    # The block ids that have been allocated so far
    # NOTE: allocated blocks could be more than the number of tokens
    # FIXME: need to check whether the block ids will be changed after
    # preemption
    allocated_block_ids: list[int]

    # The number of tokens that have been saved
    num_saved_tokens: int = 0

    @staticmethod
    def from_new_request(
        new_request: "NewRequestData",
        num_tokens_to_compute: int,
    ) -> "RequestTracker":
        """Create the request tracker from a new request.

        Args:
            new_request (NewRequestData): the new request data.
            num_tokens_to_compute (int): the number of tokens that will
                be 'computed', including the `num_computed_tokens` (vLLM's
                local cache hit) and new tokens that will be scheduled.

        """
        # vLLM 0.9.0 update: request.block_ids changed from list[int] to
        # list[list[int]] — support both layouts.
        if isinstance(new_request.block_ids[0], list):
            unfolded_block_ids = new_request.block_ids[0].copy()
        else:
            unfolded_block_ids = new_request.block_ids.copy()

        return RequestTracker(
            req_id=new_request.req_id,
            token_ids=new_request.prompt_token_ids[:num_tokens_to_compute].
            copy(),
            allocated_block_ids=unfolded_block_ids,
            num_saved_tokens=0,
        )

    def update(
        self,
        new_token_ids: list[int],
        new_block_ids: Union[tuple[list[int], ...], list[int]],
    ) -> None:
        """Update the request tracker when a running request is
        scheduled again.

        Raises:
            ValueError: if ``new_block_ids`` is neither a tuple nor a list
                (and is non-empty).
        """
        self.token_ids.extend(new_token_ids)

        if len(new_block_ids) == 0:
            blocks: list[int] = []
        elif isinstance(new_block_ids, tuple):
            # Tuple layout groups block ids; the first group is used here
            # (presumably the first KV cache group — confirm with callers).
            blocks = new_block_ids[0]
        elif isinstance(new_block_ids, list):
            blocks = new_block_ids
        else:
            raise ValueError(
                f"Unsupported new_block_ids type {type(new_block_ids)}")
        self.allocated_block_ids.extend(blocks)
|
||||
|
||||
|
||||
@dataclass
class ReqMeta:
    """Per-step transfer metadata for one request, built by the scheduler."""

    # Request id
    req_id: str
    # Request tokens
    token_ids: torch.Tensor

    block_ids: list[int]
    # # Slot mapping if exchange for block_id
    # slot_mapping: torch.Tensor
    # Skip save or not
    save_spec: Optional[SaveSpec] = None
    # load_spec; None means no load is scheduled
    load_spec: Optional[LoadSpec] = None
    # Whether this is the request's final scheduled chunk
    is_last_chunk: Optional[bool] = None

    @staticmethod
    def from_request_tracker(
        tracker: RequestTracker,
        block_size: int,
        load_spec: Optional[LoadSpec] = None,
        skip_save: Optional[bool] = False,
        is_last_chunk: Optional[bool] = None,
        discard_partial_chunks: bool = True,
    ) -> Optional["ReqMeta"]:
        """Create the request metadata from a request tracker.

        Args:
            tracker (RequestTracker): the request tracker.
            block_size (int): the block size in vLLM.
            load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
            skip_save (bool): whether to skip the save operation.
            is_last_chunk (Optional[bool]): whether this is the request's
                last scheduled chunk.
            discard_partial_chunks (bool): whether to discard partial chunks.

        Returns:
            the request metadata if we need to perform load/save
            operations, None otherwise.
        """
        input_token_ids = tracker.token_ids
        input_token_len = len(input_token_ids)

        # For save operation: do not save if the following condition is met
        # 1. has already been saved before (num_saved_tokens > 0)
        # 2. number of unsaved tokens is not reached the chunk boundary
        skip_leading_tokens = tracker.num_saved_tokens
        chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) *
                          block_size if discard_partial_chunks else 0)
        # Calculate number of tokens to save based on discard_partial_chunks
        # setting
        num_tokens_to_save = ((input_token_len // block_size * block_size)
                              if discard_partial_chunks else input_token_len)

        skip_save = skip_save or num_tokens_to_save < chunk_boundary
        if skip_save and load_spec is None:
            return None

        # If we need to save, update the number of saved tokens
        if not skip_save:
            tracker.num_saved_tokens = num_tokens_to_save
        save_spec = SaveSpec(skip_leading_tokens, not skip_save)

        # Calculate the token ids and slot mappings for load and save
        # OPTIMIZATION: pre-allocate the buffer for token ids and block ids
        token_ids = torch.tensor(input_token_ids)[:num_tokens_to_save]

        # For load operation: check whether the request is scheduled to load
        if load_spec is not None and load_spec.can_load:
            logger.debug(
                "Scheduled to load %d tokens for request %s",
                load_spec.mooncake_cached_tokens,
                tracker.req_id,
            )
        else:
            # Do not load if not in `can_load` state
            load_spec = None
        # Lazy %-style args: the message is only built when DEBUG is enabled.
        logger.debug("request:%s, meta save spec:%s, meta load spec:%s",
                     tracker.req_id, save_spec, load_spec)
        return ReqMeta(
            req_id=tracker.req_id,
            token_ids=token_ids,
            block_ids=tracker.allocated_block_ids,
            save_spec=save_spec,
            load_spec=load_spec,
            is_last_chunk=is_last_chunk,
        )
|
||||
|
||||
|
||||
class MooncakeConnectorMetadata(KVConnectorMetadata):
    """Connector metadata shipped from the scheduler to the worker each step."""

    def __init__(self, unfinished_request_ids):
        # Per-request metadata (ReqMeta) accumulated for this step.
        self.requests: list[ReqMeta] = []
        # Ids of requests the scheduler still considers unfinished.
        self.unfinished_request_ids = unfinished_request_ids

    def add_request(self, req_meta: ReqMeta) -> None:
        """Add a request to the metadata.

        Args:
            req_meta (ReqMeta): the request metadata.
        """
        self.requests.append(req_meta)
|
||||
|
||||
|
||||
@dataclass
class LasyerMultiBlockReqMeta:
    """Layer-wise transfer description covering several token spans.

    NOTE(review): the "Lasyer" spelling is kept because other modules
    refer to the class by this name.
    """

    # Request id
    req_id: str
    # One store key per token span to transfer for this layer
    keys: list[LayerMooncakeEngineKey]
    # Start token offsets, parallel to `keys`
    starts: list[int]
    # End token offsets, parallel to `keys`
    ends: list[int]
    # vLLM block ids backing the request
    block_ids: list[int]
    # Index of the layer being transferred
    layer_id: int
|
||||
|
||||
|
||||
@dataclass
class MooncakeStoreConfig:
    """Mooncake store connection settings loaded from a JSON config file."""

    local_hostname: str
    metadata_server: str
    # May be given as a "4gb"-style string in the file; parsed to bytes.
    global_segment_size: Union[int, str]
    local_buffer_size: int
    protocol: str
    device_name: str
    master_server_address: str
    # When True, the Ascend direct-transfer path is used.
    use_ascend_direct: bool

    @staticmethod
    def from_file(file_path: str) -> "MooncakeStoreConfig":
        """Build a config from the JSON file at `file_path`."""
        with open(file_path) as file:
            config = json.load(file)
        return MooncakeStoreConfig(
            local_hostname=config.get("local_hostname"),
            metadata_server=config.get("metadata_server"),
            global_segment_size=_parse_global_segment_size(
                config.get("global_segment_size",
                           DEFAULT_GLOBAL_SEGMENT_SIZE)),
            local_buffer_size=(config.get("local_buffer_size",
                                          DEFAULT_LOCAL_BUFFER_SIZE)),
            protocol=config.get("protocol", "tcp"),
            device_name=config.get("device_name", ""),
            master_server_address=config.get("master_server_address"),
            use_ascend_direct=config.get("use_ascend_direct", False))

    @staticmethod
    def load_from_env() -> "MooncakeStoreConfig":
        """Build a config from the file named by MOONCAKE_CONFIG_PATH.

        Raises:
            ValueError: if the environment variable is not set.
        """
        config_path = os.getenv("MOONCAKE_CONFIG_PATH")
        if not config_path:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
        return MooncakeStoreConfig.from_file(config_path)
|
||||
|
||||
|
||||
def _parse_global_segment_size(value) -> int:
|
||||
"""
|
||||
Parse storage size strings with support for units: GB, MB, KB, B
|
||||
|
||||
Args:
|
||||
value: Input value (int, str, or other convertible types)
|
||||
|
||||
Returns:
|
||||
int: Size in bytes
|
||||
|
||||
Raises:
|
||||
ValueError: For invalid format, missing number, or negative values
|
||||
TypeError: For unsupported input types
|
||||
"""
|
||||
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
elif not isinstance(value, str):
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError) as e:
|
||||
raise TypeError(
|
||||
f"Unsupported type for global_segment_size: {type(value)}"
|
||||
) from e
|
||||
|
||||
cleaned_input = value.strip().lower()
|
||||
if not cleaned_input:
|
||||
raise ValueError("global segment size cannot be empty.")
|
||||
|
||||
UNIT_MULTIPLIERS = {
|
||||
'gb': 1024**3, # 1 GB = 1024^3 bytes
|
||||
'mb': 1024**2, # 1 MB = 1024^2 bytes
|
||||
'kb': 1024, # 1 KB = 1024 bytes
|
||||
'b': 1 # 1 B = 1 byte
|
||||
}
|
||||
pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$'
|
||||
match = re.match(pattern, cleaned_input)
|
||||
|
||||
if not match:
|
||||
raise ValueError(f"Invalid format: '{value}'")
|
||||
|
||||
number_str = match.group(1)
|
||||
unit = match.group(2) or 'b'
|
||||
|
||||
multiplier = UNIT_MULTIPLIERS[unit]
|
||||
return _convert_to_bytes(number_str, multiplier, value)
|
||||
|
||||
|
||||
def _convert_to_bytes(number_str: str, multiplier: int,
|
||||
original_input: str) -> int:
|
||||
"""
|
||||
Convert numeric string to byte count
|
||||
|
||||
Args:
|
||||
number_str: Numeric portion of input
|
||||
multiplier: Unit conversion factor
|
||||
original_input: Original input string (for error messages)
|
||||
|
||||
Returns:
|
||||
int: Byte count
|
||||
|
||||
Raises:
|
||||
ValueError: For invalid numbers or negative results
|
||||
"""
|
||||
try:
|
||||
numeric_value = float(number_str)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid numeric value '{number_str}' in: '{original_input}'")
|
||||
# Calculate byte count
|
||||
try:
|
||||
byte_count = int(numeric_value * multiplier)
|
||||
except OverflowError:
|
||||
raise ValueError(f"Storage size too large: '{original_input}'")
|
||||
return byte_count
|
||||
@@ -1,282 +0,0 @@
|
||||
import queue
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from vllm.utils import logger
|
||||
|
||||
from vllm_ascend.distributed.mooncake.config_data import (
|
||||
ChunkedTokenDatabase, LasyerMultiBlockReqMeta)
|
||||
from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore
|
||||
|
||||
|
||||
class KVTransferThread(threading.Thread):
    """Base daemon thread that drains a queue of KV-cache transfer requests
    and moves KV blocks between device memory and the Mooncake store.

    Subclasses implement `_handle_request` for their transfer direction.
    """

    def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
                 local_kv_caches_base_addr: list[int],
                 token_database: ChunkedTokenDatabase, block_len: list[int],
                 block_size: int, ready_event: threading.Event, name: str):
        super().__init__(daemon=True, name=name)
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.m_store = m_store
        # Set by run() once the thread's loop is about to start.
        self.ready_event = ready_event
        self.kv_caches_base_addr = local_kv_caches_base_addr
        self.block_len = block_len
        self.token_database = token_database
        self.block_size = block_size
        # Guards `finished_requests` (written by this thread, read by others).
        self.done_task_lock = threading.Lock()
        # TODO(jianzs): find a better way to detect MLA.
        self.use_mla = len(block_len) == 2

        self.request_queue: queue.Queue[Any] = queue.Queue()
        # TODO(jianzs): make this configurable
        self.executor = ThreadPoolExecutor(max_workers=32)
        self.finished_requests: set[str] = set()

    def prepare_value(self, start: int, end: int, block_ids: list[int]):
        """Compute (addresses, sizes) covering tokens [start, end) across all
        KV cache regions; also returns the vLLM block id backing the span."""
        addr_list = []
        size_list = []
        block_id = block_ids[start // self.block_size]
        for index, base_addr in enumerate(self.kv_caches_base_addr):
            # MLA layouts alternate two block lengths per region.
            block_len = (self.block_len[index % 2]
                         if self.use_mla else self.block_len[0])

            addr = base_addr + block_id * block_len
            # Bytes for (end - start) tokens within the block.
            length = int(block_len / self.block_size * (end - start))
            addr_list.append(addr)
            size_list.append(length)
        return addr_list, size_list, block_id

    def prepare_value_layer(self, start: int, end: int, block_ids: list[int],
                            layer_id: int):
        """Compute ([addr_k, addr_v], sizes) for tokens [start, end) of a
        single layer's K/V regions."""
        block_id = block_ids[start // self.block_size]
        if self.use_mla:
            addr_k = self.kv_caches_base_addr[layer_id *
                                              2] + block_id * self.block_len[0]
            addr_v = self.kv_caches_base_addr[layer_id * 2 +
                                              1] + block_id * self.block_len[1]
            length_k = int(self.block_len[0] / self.block_size * (end - start))
            length_v = int(self.block_len[1] / self.block_size * (end - start))
            size_list = [length_k, length_v]
        else:
            addr_k = self.kv_caches_base_addr[layer_id *
                                              2] + block_id * self.block_len[0]
            addr_v = self.kv_caches_base_addr[layer_id * 2 +
                                              1] + block_id * self.block_len[0]
            length = int(self.block_len[0] / self.block_size * (end - start))
            size_list = [length, length]
        addr_list = [addr_k, addr_v]
        return addr_list, size_list

    def add_request(
        self,
        req_id: str,
        tokens: torch.Tensor,
        block_ids: list[int],
        mask: Optional[torch.Tensor] = None,
        is_last_chunk: Optional[bool] = None,
    ) -> None:
        # NOTE: annotation fixed from `torch.Tensor`; this method only
        # enqueues the request and returns nothing.
        req = ({
            "req_id": req_id,
            "tokens": tokens,
            "block_ids": block_ids,
            "mask": mask,
            "is_last_chunk": is_last_chunk,
        })
        self.request_queue.put(req)

    def get_and_clear_finished_requests(self) -> set[str]:
        """
        Get and clear the requests that have been completed.
        Returns:
            A set of request IDs that have been completed.
        """
        with self.done_task_lock:
            finished_requests = self.finished_requests.copy()
            self.finished_requests.clear()
        return finished_requests

    def set_finished_request(self, req_id):
        # Record completion under the lock; consumed by
        # get_and_clear_finished_requests() from another thread.
        with self.done_task_lock:
            self.finished_requests.add(req_id)

    def run(self):
        """Run the thread to handle KV cache transfer requests."""
        self.ready_event.set()
        while True:
            try:
                request_data = self.request_queue.get()
                if request_data is None:
                    logger.warning("Received a None request!")
                    self.request_queue.task_done()
                    continue
                self._handle_request(request_data)
            except Exception as e:
                # Keep the worker alive; individual failures are logged only.
                logger.error(f"Error in KVCacheTransferThread: {e}")

    def _handle_request(self, req_meta: dict[str, Any]):
        """Transfer one request's KV data; overridden by subclasses."""
        pass
|
||||
|
||||
|
||||
class KVCacheStoreSendingThread(KVTransferThread):
    """Worker thread that saves KV cache blocks into the Mooncake store."""

    def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
                 local_kv_caches_base_addr: list[int],
                 token_database: ChunkedTokenDatabase, block_len: list[int],
                 block_size: int, ready_event: threading.Event):
        super().__init__(tp_rank,
                         tp_size,
                         m_store,
                         local_kv_caches_base_addr,
                         token_database,
                         block_len,
                         block_size,
                         ready_event,
                         name="KVCacheSendingThread")

    def _handle_request(self, req_meta: dict[str, Any]):
        tokens = req_meta["tokens"]
        mask = req_meta["mask"]
        block_ids = req_meta["block_ids"]
        req_id = req_meta["req_id"]
        is_last_chunk = req_meta["is_last_chunk"]
        if self.m_store.config.use_ascend_direct:
            # Batch path: collect all chunk keys/addresses first, then issue
            # a single batched put.
            addr_list = []
            size_list = []
            key_list = []
            blockIds = []
            for start, end, key in self.token_database.process_tokens(
                    tokens, mask):
                addr, size, block_id = self.prepare_value(
                    start, end, block_ids)
                key_list.append(key.to_string())
                addr_list.append(addr)
                size_list.append(size)
                blockIds.append(block_id)
            # Synchronize so device-side KV writes are complete before the
            # store reads from those addresses.
            torch.npu.current_stream().synchronize()
            self.m_store.put_batch(key_list, addr_list, size_list, blockIds)
        else:
            torch.npu.current_stream().synchronize()
            # Per-chunk path: one put per token chunk.
            for start, end, key in self.token_database.process_tokens(
                    tokens, mask):
                addr, size, _ = self.prepare_value(start, end, block_ids)
                self.m_store.put(key, addr, size)
        # Only the request's last chunk marks it finished.
        if is_last_chunk:
            self.set_finished_request(req_id)
        self.request_queue.task_done()
|
||||
|
||||
|
||||
class KVCacheStoreRecvingThread(KVTransferThread):
    """Worker thread that loads KV cache blocks from the Mooncake store."""

    def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
                 local_kv_caches_base_addr: list[int],
                 token_database: ChunkedTokenDatabase, block_len: list[int],
                 block_size: int, ready_event: threading.Event):
        super().__init__(tp_rank,
                         tp_size,
                         m_store,
                         local_kv_caches_base_addr,
                         token_database,
                         block_len,
                         block_size,
                         ready_event,
                         name="KVCacheStoreRecvingThread")

    def _handle_request(self, req_meta: dict[str, Any]):
        tokens = req_meta["tokens"]
        mask = req_meta["mask"]
        block_ids = req_meta["block_ids"]
        req_id = req_meta["req_id"]
        if self.m_store.config.use_ascend_direct:
            # Batch path: collect all chunk keys/addresses, then one batched
            # get into device memory.
            addr_list = []
            size_list = []
            key_list = []
            blockIds = []
            for start, end, key in self.token_database.process_tokens(
                    tokens, mask):
                addr, size, block_id = self.prepare_value(
                    start, end, block_ids)
                key_list.append(key.to_string())
                addr_list.append(addr)
                size_list.append(size)
                blockIds.append(block_id)
            self.m_store.get_batch(key_list, addr_list, size_list, blockIds)
        else:
            # Per-chunk path: one get per token chunk.
            for start, end, key in self.token_database.process_tokens(
                    tokens, mask):
                addr, size, _ = self.prepare_value(start, end, block_ids)
                self.m_store.get(key, addr, size)
        # Loads finish the request as soon as all chunks are fetched.
        self.set_finished_request(req_id)
        self.request_queue.task_done()
|
||||
|
||||
|
||||
class KVCacheStoreLayerSendingThread(KVTransferThread):
    """Worker thread that saves KV blocks layer by layer (layer-wise mode)."""

    def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
                 local_kv_caches_base_addr: list[int],
                 token_database: ChunkedTokenDatabase, block_len: list[int],
                 block_size: int, ready_event: threading.Event,
                 num_layers: int):
        super().__init__(tp_rank,
                         tp_size,
                         m_store,
                         local_kv_caches_base_addr,
                         token_database,
                         block_len,
                         block_size,
                         ready_event,
                         name="KVCacheStoreLayerSendingThread")
        # A request is finished once its last layer has been sent.
        self.final_layer_id = num_layers - 1

    def add_request(  # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta) -> None:
        # NOTE: annotation fixed from `torch.Tensor`; this only enqueues.
        self.request_queue.put(req_meta)

    def _handle_request(  # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta):
        # Synchronize so device-side KV writes for this layer are complete
        # before the store reads from those addresses.
        torch.npu.current_stream().synchronize()
        for index, key in enumerate(req_meta.keys):
            addr, size = self.prepare_value_layer(req_meta.starts[index],
                                                  req_meta.ends[index],
                                                  req_meta.block_ids,
                                                  req_meta.layer_id)
            self.m_store.put(key, addr, size)
        if req_meta.layer_id == self.final_layer_id:
            self.set_finished_request(req_meta.req_id)
        self.request_queue.task_done()
|
||||
|
||||
|
||||
class KVCacheStoreLayerRecvingThread(KVTransferThread):
    """Worker thread that loads KV blocks layer by layer (layer-wise mode)."""

    def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
                 local_kv_caches_base_addr: list[int],
                 token_database: ChunkedTokenDatabase, block_len: list[int],
                 block_size: int, ready_event: threading.Event,
                 get_event: threading.Event):
        super().__init__(tp_rank,
                         tp_size,
                         m_store,
                         local_kv_caches_base_addr,
                         token_database,
                         block_len,
                         block_size,
                         ready_event,
                         name="KVCacheStoreLayerRecvingThread")
        # Signalled after each handled request so the consumer can proceed.
        self.get_event = get_event

    def add_request(  # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta) -> None:
        # NOTE: annotation fixed from `torch.Tensor`; this only enqueues.
        self.request_queue.put(req_meta)

    def _handle_request(  # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta):
        for index, key in enumerate(req_meta.keys):
            addr, size = self.prepare_value_layer(req_meta.starts[index],
                                                  req_meta.ends[index],
                                                  req_meta.block_ids,
                                                  req_meta.layer_id)
            self.m_store.get(key, addr, size)
        self.request_queue.task_done()
        self.get_event.set()
|
||||
@@ -1,127 +0,0 @@
|
||||
# Standard
|
||||
import os
|
||||
|
||||
# Third Party
|
||||
from mooncake.store import ReplicateConfig # type: ignore
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
|
||||
from vllm.utils import logger
|
||||
from vllm.utils.network_utils import get_ip
|
||||
|
||||
from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey
|
||||
from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te
|
||||
|
||||
from .config_data import MooncakeStoreConfig
|
||||
|
||||
METADATA_BYTES_LEN = 24
|
||||
BASE_PORT = int(os.getenv("VLLM_BASE_PORT", "8790"))
|
||||
|
||||
|
||||
class Mooncakestore():
    """Thin wrapper around MooncakeDistributedStore for Ascend devices.

    Resolves the NPU device for this TP rank, loads the store config from
    the MOONCAKE_CONFIG_PATH file, and exposes get/put (single and batch)
    operations on device-memory addresses.
    """

    def __init__(self, parallel_config: ParallelConfig):
        try:
            from mooncake.store import MooncakeDistributedStore  # type: ignore
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
                "to run vLLM with MooncakeConnector.") from e
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = parallel_config.tensor_parallel_size
        dp_rank = parallel_config.data_parallel_rank_local
        # Resolve the physical NPU id backing this TP rank.
        all_device_ids = os.getenv("ASCEND_RT_VISIBLE_DEVICES", None)
        if not all_device_ids:
            device_ids_list = list(
                range(dp_rank * tp_size, (dp_rank + 1) * tp_size))
        else:
            device_ids_list = list(map(int, all_device_ids.split(',')))
        assert len(device_ids_list) > tp_rank
        device_id = device_ids_list[tp_rank]
        self.config = MooncakeStoreConfig.load_from_env()
        self.store = MooncakeDistributedStore()
        if self.config.protocol == "ascend" and not self.config.use_ascend_direct:
            # Legacy ascend protocol: the device id is encoded in the
            # hostname and the store owns its transfer engine.
            local_hostname = get_ip() + ":" + str(BASE_PORT + int(device_id)) + \
                ":npu_" + str(device_id)
            ret = self.store.setup(local_hostname, self.config.metadata_server,
                                   self.config.global_segment_size,
                                   self.config.local_buffer_size,
                                   self.config.protocol,
                                   self.config.device_name,
                                   self.config.master_server_address)
        else:
            # Direct path: share the process-wide transfer engine and keep
            # the local segment name for preferred allocation in put_batch.
            local_hostname = get_ip()
            transfer_engine = get_global_te(local_hostname, device_name=None)
            self.local_seg = local_hostname + ":" + str(
                transfer_engine.get_rpc_port())
            ret = self.store.setup(self.local_seg, self.config.metadata_server,
                                   self.config.global_segment_size,
                                   self.config.local_buffer_size,
                                   self.config.protocol,
                                   self.config.device_name,
                                   self.config.master_server_address,
                                   transfer_engine.get_engine())
        if ret != 0:
            msg = "Initialize mooncake failed."
            logger.error(msg)
            raise RuntimeError(msg)

    def exists(self, key: MooncakeEngineKey) -> bool:
        """Return True if `key` is present in the store."""
        return self.store.is_exist(key.to_string()) == 1

    def batch_exists(self, keys: list[str]) -> list[int]:
        """Return per-key existence flags for `keys`."""
        return self.store.batch_is_exist(keys)

    def register_buffer(self, ptr, length):
        """Register a device memory region with the store."""
        return self.store.register_buffer(ptr, length)

    def get_batch(self, keys: list[str], addrs: list[list[int]],
                  sizes: list[list[int]], block_ids: list[int]):
        """Fetch several keys into the given device buffers (best effort)."""
        try:
            res = self.store.batch_get_into_multi_buffers(
                keys, addrs, sizes, True)
            for value in res:
                if value < 0:
                    logger.error("Failed to get key %s,res:%s", keys, res)
        except Exception:
            # Best-effort: log with traceback, do not propagate.
            logger.exception("Failed to get key %s.", keys)

    def put_batch(self, keys: list[str], addrs: list[list[int]],
                  sizes: list[list[int]], block_ids: list[int]):
        """Store several keys from the given device buffers (best effort)."""
        try:
            config = ReplicateConfig()
            # Prefer allocating replicas on this node's own segment.
            config.preferred_segment = self.local_seg
            config.prefer_alloc_in_same_node = True
            res = self.store.batch_put_from_multi_buffers(
                keys, addrs, sizes, config)
            for value in res:
                if value < 0:
                    logger.error("Failed to put key %s,res:%s", keys, res)
        except Exception:
            logger.exception("Failed to put key %s.", keys)

    def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
        """Fetch one key into device buffers.

        Returns the store's result list; ``[-1]`` if the call raised before
        producing a result.
        """
        expect_res = sum(size)
        key_str = key.to_string()
        # Sentinel so a raising store call cannot leave `res` unbound
        # (the original code raised NameError at `return res` in that case).
        res = [-1]
        try:
            res = self.store.batch_get_into_ascend(key_str, addr, size)
            if res[0] != expect_res:
                logger.error("Failed to get key: [%s] .", key_str)
        except Exception:
            logger.exception("Failed to get key: [%s] .", key_str)
        return res

    def put(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
        """Store one key from device buffers.

        Returns the store's result list; ``[-1]`` if the call raised before
        producing a result.
        """
        key_str = key.to_string()
        # Sentinel for the same unbound-variable hazard as in get().
        ret = [-1]
        try:
            ret = self.store.batch_put_from_ascend(key_str, addr, size)
            if ret[0] != 0:
                logger.error("Failed to put key %s.", key_str)
        except Exception:
            logger.exception("Failed to put key %s.", key_str)

        return ret

    def close(self):
        """Shut down the underlying store connection."""
        self.store.close()
        logger.info("Closed the mooncake store connection")
|
||||
@@ -1,38 +0,0 @@
|
||||
import ipaddress
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from mooncake.engine import TransferEngine # type: ignore
|
||||
|
||||
_global_te = None
|
||||
_global_te_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_global_te(hostname: str, device_name: Optional[str]):
    """Return the process-wide TransferEngine, creating it on first call.

    Args:
        hostname: local host address (must not be an IPv6 literal).
        device_name: Ascend device name; ``None`` is treated as "".

    Raises:
        RuntimeError: for IPv6 hostnames, a missing mooncake install, or
            a failed engine initialization.
    """
    try:
        ip = ipaddress.ip_address(hostname)
        if isinstance(ip, ipaddress.IPv6Address):
            raise RuntimeError(
                "The backend of mooncake's Ascend Direct Xfer Library currently does not support IPv6."
            )
    except ValueError:
        # Not a bare IP literal (e.g. a DNS name); accepted as-is.
        pass

    global _global_te
    if _global_te is None:
        with _global_te_lock:
            # Double-Checked Locking
            if _global_te is None:
                if TransferEngine is None:
                    raise RuntimeError("mooncake is not available")
                transfer_engine = TransferEngine()
                device_name = device_name if device_name is not None else ""
                ret_value = transfer_engine.initialize(hostname,
                                                       "P2PHANDSHAKE",
                                                       "ascend", device_name)
                if ret_value != 0:
                    raise RuntimeError(
                        f"TransferEngine initialization failed with ret_value: {ret_value}"
                    )
                # Publish only after successful initialization.
                _global_te = transfer_engine
    return _global_te
|
||||
@@ -31,11 +31,12 @@ from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tp_group)
|
||||
from vllm.utils import logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
|
||||
from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te
|
||||
from vllm_ascend.distributed.mooncake_transfer_engine import global_te
|
||||
from vllm_ascend.distributed.utils import get_transfer_timeout_value
|
||||
from vllm_ascend.utils import prefill_context_parallel_enable
|
||||
|
||||
@@ -634,7 +635,10 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
|
||||
class MooncakeConnector(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional[KVCacheConfig] = None):
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
self.engine_id = vllm_config.kv_transfer_config.engine_id
|
||||
|
||||
@@ -944,7 +948,7 @@ class MooncakeConnectorWorker:
|
||||
else:
|
||||
hostname = f"{self.side_channel_host}:0:npu_{self.device_id}"
|
||||
logger.info("Initializing Mooncake work %s", engine_id)
|
||||
self.engine = get_global_te(hostname, device_name=None)
|
||||
self.engine = global_te.get_transfer_engine(hostname, device_name=None)
|
||||
self.te_rpc_port = self.engine.get_rpc_port()
|
||||
|
||||
# Background thread for sending or receiving KV caches.
|
||||
@@ -1054,6 +1058,8 @@ class MooncakeConnectorWorker:
|
||||
|
||||
self.kv_caches = kv_caches
|
||||
kv_caches_base_addr = []
|
||||
ptrs = []
|
||||
lengths = []
|
||||
for cache_or_caches in kv_caches.values():
|
||||
# Normalize to always be a list of caches
|
||||
if self.use_mla:
|
||||
@@ -1061,13 +1067,15 @@ class MooncakeConnectorWorker:
|
||||
base_addr = cache.data_ptr()
|
||||
region_len = self.num_blocks * self.block_len[i % 2]
|
||||
kv_caches_base_addr.append(base_addr)
|
||||
self._register(base_addr, region_len)
|
||||
ptrs.append(base_addr)
|
||||
lengths.append(region_len)
|
||||
elif self.use_sparse:
|
||||
for i, cache in enumerate(cache_or_caches, 0):
|
||||
base_addr = cache.data_ptr()
|
||||
region_len = self.num_blocks * self.block_len[i % 3]
|
||||
kv_caches_base_addr.append(base_addr)
|
||||
self._register(base_addr, region_len)
|
||||
ptrs.append(base_addr)
|
||||
lengths.append(region_len)
|
||||
else:
|
||||
cache_list = [
|
||||
cache_or_caches
|
||||
@@ -1076,8 +1084,9 @@ class MooncakeConnectorWorker:
|
||||
base_addr = cache.data_ptr()
|
||||
region_len = self.num_blocks * self.block_len[0]
|
||||
kv_caches_base_addr.append(base_addr)
|
||||
self._register(base_addr, region_len)
|
||||
|
||||
ptrs.append(base_addr)
|
||||
lengths.append(region_len)
|
||||
global_te.register_buffer(ptrs, lengths)
|
||||
# After KV Caches registered, start the sending or receiving thread.
|
||||
metadata = MooncakeAgentMetadata(
|
||||
engine_id=self.engine_id,
|
||||
@@ -1101,14 +1110,6 @@ class MooncakeConnectorWorker:
|
||||
self.kv_recv_thread.start()
|
||||
ready_event.wait()
|
||||
|
||||
def _register(self, ptr, length):
|
||||
logger.debug(
|
||||
"Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
|
||||
"block_lens=%s", ptr, length, self.num_blocks, self.block_len)
|
||||
ret_value = self.engine.register_memory(ptr, length)
|
||||
if ret_value != 0:
|
||||
raise RuntimeError("Mooncake memory registration failed.")
|
||||
|
||||
def get_finished(self) -> tuple[set[str], set[str]]:
|
||||
done_sending = (
|
||||
self.kv_send_thread.
|
||||
|
||||
@@ -30,6 +30,7 @@ from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
|
||||
from vllm.utils import logger
|
||||
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
@@ -359,7 +360,10 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
|
||||
|
||||
class MooncakeLayerwiseConnector(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional[KVCacheConfig] = None):
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
self.engine_id = vllm_config.kv_transfer_config.engine_id
|
||||
self._connector_metadata = MooncakeLayerwiseConnectorMetadata()
|
||||
|
||||
53
vllm_ascend/distributed/mooncake_transfer_engine.py
Normal file
53
vllm_ascend/distributed/mooncake_transfer_engine.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import ipaddress
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from mooncake.engine import TransferEngine # type: ignore
|
||||
|
||||
|
||||
class GlobalTE():
    """Process-wide holder for a single mooncake TransferEngine.

    Lazily creates the engine on first use and registers KV-cache buffers
    exactly once, both guarded by locks.
    """

    def __init__(self):
        # Lazily-created singleton engine instance.
        self.transfer_engine = None
        # True once KV buffers have been registered (done at most once).
        self.is_register_buffer: bool = False
        self.transfer_engine_lock = threading.Lock()
        self.register_buffer_lock = threading.Lock()

    def get_transfer_engine(self, hostname: str, device_name: Optional[str]):
        """Return the shared TransferEngine, initializing it on first call.

        Raises:
            RuntimeError: for IPv6 hostnames, a missing mooncake install,
                or a failed engine initialization.
        """
        try:
            ip = ipaddress.ip_address(hostname)
            if isinstance(ip, ipaddress.IPv6Address):
                raise RuntimeError(
                    "The backend of mooncake's Ascend Direct Xfer Library currently does not support IPv6."
                )
        except ValueError:
            # Not a bare IP literal (e.g. a DNS name); accepted as-is.
            pass
        if self.transfer_engine is None:
            with self.transfer_engine_lock:
                # Double-Checked Locking
                if self.transfer_engine is None:
                    if TransferEngine is None:
                        raise RuntimeError("mooncake is not available")
                    self.transfer_engine = TransferEngine()
                    device_name = device_name if device_name is not None else ""
                    ret_value = self.transfer_engine.initialize(
                        hostname, "P2PHANDSHAKE", "ascend", device_name)
                    if ret_value != 0:
                        raise RuntimeError(
                            f"TransferEngine initialization failed with ret_value: {ret_value}"
                        )
        return self.transfer_engine

    def register_buffer(self, ptrs: list[int], sizes: list[int]):
        """Register the given memory regions with the engine, once.

        Subsequent calls are no-ops. NOTE(review): if register_memory raises
        partway through, earlier regions stay registered and a retry will
        re-register them — confirm the engine tolerates that.
        """
        with self.register_buffer_lock:
            assert self.transfer_engine is not None, "Transfer engine must be initialized"
            if self.is_register_buffer:
                return
            for ptr, size in zip(ptrs, sizes):
                ret_value = self.transfer_engine.register_memory(ptr, size)
                if ret_value != 0:
                    raise RuntimeError("Mooncake memory registration failed.")
            self.is_register_buffer = True
||||
|
||||
|
||||
global_te = GlobalTE()
|
||||
Reference in New Issue
Block a user