diff --git a/pyproject.toml b/pyproject.toml
index cad82e5e..b78e5d89 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -51,34 +51,48 @@ line-length = 120
 # Folder to be modified
 exclude = [
     "tests/**",
-    "vllm_ascend/attention/mla_v1.py",
-    "vllm_ascend/attention/sfa_v1.py",
-    "vllm_ascend/core",
-    "vllm_ascend/distributed",
-    "vllm_ascend/eplb",
-    "vllm_ascend/kv_offload",
-    "vllm_ascend/lora",
-    "vllm_ascend/model_loader",
-    "vllm_ascend/ops/fused_moe",
+    # (3)
+    "vllm_ascend/attention/*.py",
+    "vllm_ascend/core/*.py",
+    "vllm_ascend/distributed/device_communicators/**",
+    "vllm_ascend/distributed/utils.py",
+    # (5)
+    "vllm_ascend/distributed/kv_transfer/kv_pool/**",
+    "vllm_ascend/distributed/kv_transfer/utils/**",
+    "vllm_ascend/kv_offload/**",
+    "vllm_ascend/lora/**",
+    # (6)
+    "vllm_ascend/eplb/**",
+    "vllm_ascend/model_loader/netloader/**",
+    "vllm_ascend/patch/**",
+    # (7)
+    "vllm_ascend/quantization/**",
+    "vllm_ascend/sample/*.py",
+    "vllm_ascend/worker/v2/**",
+    "vllm_ascend/worker/block_table.py",
+    "vllm_ascend/worker/npu_input_batch.py",
+    # (8)
+    "vllm_ascend/ops/__init__.py",
     "vllm_ascend/ops/activation.py",
     "vllm_ascend/ops/flashcomm2_oshard_manager.py",
-    "vllm_ascend/ops/layer_shard_linear.py",
     "vllm_ascend/ops/layernorm.py",
-    "vllm_ascend/ops/linear_op.py",
-    "vllm_ascend/ops/linear.py",
     "vllm_ascend/ops/mla.py",
     "vllm_ascend/ops/mm_encoder_attention.py",
     "vllm_ascend/ops/register_custom_ops.py",
     "vllm_ascend/ops/rotary_embedding.py",
     "vllm_ascend/ops/vocab_parallel_embedding.py",
     "vllm_ascend/ops/weight_prefetch.py",
-    "vllm_ascend/ops/__init__.py",
-    "vllm_ascend/patch",
-    "vllm_ascend/quantization",
-    "vllm_ascend/sample",
-    "vllm_ascend/spec_decode",
-    "vllm_ascend/worker",
-    "vllm_ascend/xlite",
+    "vllm_ascend/spec_decode/**",
+    # (9)
+    "vllm_ascend/worker/model_runner_v1.py",
+    "vllm_ascend/worker/pcp_utils.py",
+    # (10)
+    "vllm_ascend/ops/*linear*.py",
+    "vllm_ascend/worker/worker.py",
+    "vllm_ascend/distributed/parallel_state.py",
+    "vllm_ascend/xlite/*.py",
+    # (11)
+    "vllm_ascend/ops/fused_moe/**",
 ]
 
 [tool.ruff.lint]
diff --git a/vllm_ascend/distributed/kv_transfer/__init__.py b/vllm_ascend/distributed/kv_transfer/__init__.py
index f2bf8d6f..0450a104 100644
--- a/vllm_ascend/distributed/kv_transfer/__init__.py
+++ b/vllm_ascend/distributed/kv_transfer/__init__.py
@@ -15,31 +15,32 @@
 # limitations under the License.
 #
-from vllm.distributed.kv_transfer.kv_connector.factory import \
-    KVConnectorFactory
+from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
 
 
 def register_connector():
     KVConnectorFactory.register_connector(
-        "MooncakeConnectorV1",
-        "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector",
-        "MooncakeConnector")
+        "MooncakeConnectorV1", "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector", "MooncakeConnector"
+    )
 
     KVConnectorFactory.register_connector(
         "MooncakeConnectorStoreV1",
         "vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.ascend_store_connector",
-        "AscendStoreConnector")
+        "AscendStoreConnector",
+    )
 
     KVConnectorFactory.register_connector(
         "AscendStoreConnector",
         "vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.ascend_store_connector",
-        "AscendStoreConnector")
+        "AscendStoreConnector",
+    )
 
     KVConnectorFactory.register_connector(
         "MooncakeLayerwiseConnector",
         "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector",
-        "MooncakeLayerwiseConnector")
+        "MooncakeLayerwiseConnector",
+    )
 
     KVConnectorFactory.register_connector(
-        "UCMConnector", "vllm_ascend.distributed.kv_transfer.kv_pool.ucm_connector",
-        "UCMConnectorV1")
+        "UCMConnector", "vllm_ascend.distributed.kv_transfer.kv_pool.ucm_connector", "UCMConnectorV1"
+    )
diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py
index 130d1ae1..d03ba3ed 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py
@@ -9,11 +9,11 @@ import random
 import struct
 import threading
 import time
-from collections import defaultdict, deque
+from collections import OrderedDict, defaultdict, deque
 from collections.abc import Iterator
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, List, Optional, OrderedDict, Tuple, TypedDict
+from typing import TYPE_CHECKING, Any, TypedDict
 
 import msgspec
 import numpy as np
@@ -26,13 +26,19 @@ from vllm import envs
 from vllm.config import VllmConfig
 from vllm.distributed import get_pcp_group
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
-    KVConnectorBase_V1, KVConnectorHandshakeMetadata, KVConnectorMetadata,
-    KVConnectorRole)
+    KVConnectorBase_V1,
+    KVConnectorHandshakeMetadata,
+    KVConnectorMetadata,
+    KVConnectorRole,
+)
 from vllm.distributed.parallel_state import (
     get_decode_context_model_parallel_rank,
-    get_decode_context_model_parallel_world_size, get_pp_group,
-    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
-    get_tp_group)
+    get_decode_context_model_parallel_world_size,
+    get_pp_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    get_tp_group,
+)
 from vllm.distributed.utils import get_pp_indices
 from vllm.logger import logger
 from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
@@ -87,7 +93,6 @@ class ReqMeta:
 
 @dataclass
 class SizedDict(OrderedDict):
-
     def __init__(self, max_size=16000, *args, **kwargs):
         self.max_size = max_size
         super().__init__(*args, **kwargs)
@@ -107,7 +112,6 @@ class SizedDict(OrderedDict):
 
 
 class KVCacheTaskTracker:
-
     def __init__(self):
         super().__init__()
 
@@ -136,7 +140,8 @@ class KVCacheTaskTracker:
             self.delayed_free_requests.pop(request_id, None)
         else:
             logger.error(
-                "MooncakeConnector finish req not in reqs to process.If it is a P node, this request may have been force freed."
+                "MooncakeConnector finish req not in reqs to process."
+                "If it is a P node, this request may have been force freed."
             )
 
     def get_and_clear_finished_requests(self) -> set[str]:
@@ -166,8 +171,7 @@ class KVCacheTaskTracker:
         while self.delayed_free_requests:
             request_id = next(iter(self.delayed_free_requests))
             delay_start_time = self.delayed_free_requests[request_id]
-            if (current_time - delay_start_time
-                    > envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT):
+            if current_time - delay_start_time > envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT:
                 self.delayed_free_requests.popitem(last=False)
                 self.reqs_to_process.discard(request_id)
                 expired_requests.add(request_id)
@@ -178,12 +182,19 @@
 
 
 class KVCacheSendingThread(threading.Thread):
-
-    def __init__(self, vllm_config: VllmConfig, tp_rank: int,
-                 prefill_tp_size: int, local_engine_id: str,
-                 side_channel_host: str, side_channel_port: int,
-                 metadata: MooncakeAgentMetadata, ready_event: threading.Event,
-                 kv_caches: dict[str, Any], pcp_rank: int):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        tp_rank: int,
+        prefill_tp_size: int,
+        local_engine_id: str,
+        side_channel_host: str,
+        side_channel_port: int,
+        metadata: MooncakeAgentMetadata,
+        ready_event: threading.Event,
+        kv_caches: dict[str, Any],
+        pcp_rank: int,
+    ):
         super().__init__(daemon=True, name="KVCacheSendingThread")
         self.tp_rank = tp_rank
         self.prefill_tp_size = prefill_tp_size
@@ -213,8 +224,7 @@ class KVCacheSendingThread(threading.Thread):
         self.task_tracker.add_not_transfer_request(request_id)
 
     def add_delayed_request(self, request_id: str, delay_start_time: float):
-        return self.task_tracker.add_delayed_request(request_id,
-                                                     delay_start_time)
+        return self.task_tracker.add_delayed_request(request_id, delay_start_time)
 
     def run(self):
         """Run the thread to handle KV cache transfer requests."""
@@ -231,16 +241,13 @@ class KVCacheSendingThread(threading.Thread):
             self.ready_event.set()
             self.run_busy_loop(sock)
         except Exception as e:
-            logger.error("Mooncake KVCacheSendingThread exception: %s",
-                         e,
-                         exc_info=True)
+            logger.error("Mooncake KVCacheSendingThread exception: %s", e, exc_info=True)
 
     def run_busy_loop(self, sock: zmq.Socket):  # type: ignore
         encoder = msgspec.msgpack.Encoder()
         encoded_data = encoder.encode(self.metadata)
         size_in_bytes = len(encoded_data)
-        logger.debug("Size of encoded MooncakeAgentMetadata: %s bytes",
-                     str(size_in_bytes))
+        logger.debug("Size of encoded MooncakeAgentMetadata: %s bytes", str(size_in_bytes))
 
         decoder = msgspec.msgpack.Decoder(type=tuple)
         while True:
@@ -267,14 +274,10 @@ class KVCacheSendingThread(threading.Thread):
                         if request_id not in self.port_send_num:
                             self.port_send_num[request_id] = 0
                         self.port_send_num[request_id] += 1
-                        device_index = self.pp_rank * self.tp_size + \
-                            self.tp_rank + self.pcp_rank * \
-                            self.prefill_tp_size
+                        device_index = self.pp_rank * self.tp_size + self.tp_rank + self.pcp_rank * self.prefill_tp_size
                         handshake_port = self.side_channel_port + device_index
-                        if self.port_send_num[request_id] >= \
-                            remote_port_send_num[handshake_port]['num']:
-                            self.task_tracker.update_done_task_count(
-                                request_id)
+                        if self.port_send_num[request_id] >= remote_port_send_num[handshake_port]["num"]:
+                            self.task_tracker.update_done_task_count(request_id)
                             del self.port_send_num[request_id]
                     else:
                         self.task_tracker.update_done_task_count(request_id)
@@ -282,40 +285,35 @@ class KVCacheSendingThread(threading.Thread):
                     while True:
                         try:
                             # Send ACK to the sender.
-                            sock.send_multipart(
-                                (identity, b"", b"ACK"),
-                                flags=zmq.NOBLOCK)  # type: ignore
+                            sock.send_multipart((identity, b"", b"ACK"), flags=zmq.NOBLOCK)  # type: ignore
                             break
                         except zmq.Again:  # type: ignore
                             # If the socket is not ready, retry sending.
-                            logger.debug(
-                                "Socket not ready, retrying to send ACK for "
-                                "request %s", msg[1])
+                            logger.debug("Socket not ready, retrying to send ACK for request %s", msg[1])
                             time.sleep(0.01)
                 else:
-                    logger.error(
-                        "Connection listener got unexpected message %s", msg)
+                    logger.error("Connection listener got unexpected message %s", msg)
             except Exception as e:
-                logger.error("Connection listener got exception %s: %s",
-                             type(e), e)
+                logger.error("Connection listener got exception %s: %s", type(e), e)
 
 
 class KVCacheRecvingThread(threading.Thread):
-
-    def __init__(self,
-                 tp_rank: int,
-                 tp_size: int,
-                 _prefill_pp_size: int,
-                 engine: TransferEngine,
-                 local_engine_id: str,
-                 local_handshake_port: int,
-                 side_channel_port: int,
-                 local_kv_caches_base_addr: list[int],
-                 block_len: list[int],
-                 ready_event: threading.Event,
-                 vllm_config: VllmConfig,
-                 kv_caches: dict[str, Any],
-                 prefill_pp_layer_partition: Optional[str] = None):
+    def __init__(
+        self,
+        tp_rank: int,
+        tp_size: int,
+        _prefill_pp_size: int,
+        engine: TransferEngine,
+        local_engine_id: str,
+        local_handshake_port: int,
+        side_channel_port: int,
+        local_kv_caches_base_addr: list[int],
+        block_len: list[int],
+        ready_event: threading.Event,
+        vllm_config: VllmConfig,
+        kv_caches: dict[str, Any],
+        prefill_pp_layer_partition: str | None = None,
+    ):
         super().__init__(daemon=True, name="KVCacheRecvingThread")
         self.tp_rank = tp_rank
         self.tp_size = tp_size
@@ -327,12 +325,9 @@ class KVCacheRecvingThread(threading.Thread):
         self.ready_event = ready_event
         self.kv_caches = kv_caches
 
-        self.kv_caches_base_addr: dict[str, dict[int, list[int]]] = \
-            SizedDict()
-        self.kv_caches_base_addr[local_engine_id][local_handshake_port] = \
-            local_kv_caches_base_addr
-        self.remote_te_port: dict[str, dict[int, int]] = \
-            SizedDict()
+        self.kv_caches_base_addr: dict[str, dict[int, list[int]]] = SizedDict()
+        self.kv_caches_base_addr[local_engine_id][local_handshake_port] = local_kv_caches_base_addr
+        self.remote_te_port: dict[str, dict[int, int]] = SizedDict()
         self.block_len = block_len
         # TODO(jianzs): find a better way to detect MLA.
         self.use_mla = len(block_len) == 2
@@ -347,8 +342,10 @@ class KVCacheRecvingThread(threading.Thread):
         self.decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
         self.remote_sockets_lock = threading.Lock()
         self.remote_sockets: dict[  # type: ignore
-            str, deque[zmq.Socket]] = defaultdict(  # type: ignore
-                deque)
+            str, deque[zmq.Socket]
+        ] = defaultdict(  # type: ignore
+            deque
+        )
         self.remote_poller = zmq.Poller()  # type: ignore
         self.timeout = 1.0  # seconds
@@ -357,10 +354,7 @@ class KVCacheRecvingThread(threading.Thread):
         self.block_size = self.vllm_config.cache_config.block_size
         self.num_layers = self.model_config.hf_text_config.num_hidden_layers
         self.pp_layer_indices = {
-            rank:
-            get_prefill_pp_indices(self.num_layers, rank,
-                                   self._prefill_pp_size,
-                                   prefill_pp_layer_partition)
+            rank: get_prefill_pp_indices(self.num_layers, rank, self._prefill_pp_size, prefill_pp_layer_partition)
             for rank in range(self._prefill_pp_size)
         }
         if not is_vl_model(vllm_config):
@@ -371,38 +365,42 @@ class KVCacheRecvingThread(threading.Thread):
         else:
             self.k_head_dim = self.model_config.hf_text_config.head_dim
             self.v_head_dim = self.model_config.hf_text_config.head_dim
-            self.num_kv_heads = max(
-                self.model_config.hf_text_config.num_key_value_heads //
-                self.tp_size, 1)
+            self.num_kv_heads = max(self.model_config.hf_text_config.num_key_value_heads // self.tp_size, 1)
         self.proc_not_transfer_request: dict[str, bool] = {}
 
-    def add_request(self,
-                    request_id: str,
-                    remote_request_id: str,
-                    local_block_ids: list[int],
-                    remote_block_ids: list[int],
-                    remote_engine_id: str,
-                    remote_host: str,
-                    remote_handshake_port: int,
-                    offset: int,
-                    tp_num_need_pulls: int,
-                    remote_port_send_num: dict[int, RemotePortInfo] = {},
-                    all_task_done: bool = False):
+    def add_request(
+        self,
+        request_id: str,
+        remote_request_id: str,
+        local_block_ids: list[int],
+        remote_block_ids: list[int],
+        remote_engine_id: str,
+        remote_host: str,
+        remote_handshake_port: int,
+        offset: int,
+        tp_num_need_pulls: int,
+        remote_port_send_num: dict[int, RemotePortInfo] | None = None,
+        all_task_done: bool = False,
+    ):
         """Add a new request to the queue for processing."""
+        if remote_port_send_num is None:
+            remote_port_send_num = {}
         logger.debug(f"Adding request {request_id} to the queue.")
-        self.request_queue.put({
-            "request_id": request_id,
-            "local_block_ids": local_block_ids,
-            "remote_block_ids": remote_block_ids,
-            "remote_engine_id": remote_engine_id,
-            "remote_request_id": remote_request_id,
-            "remote_host": remote_host,
-            "remote_handshake_port": remote_handshake_port,
-            "offset": offset,
-            "tp_num_need_pulls": tp_num_need_pulls,
-            "remote_port_send_num": remote_port_send_num,
-            "all_task_done": all_task_done
-        })
+        self.request_queue.put(
+            {
+                "request_id": request_id,
+                "local_block_ids": local_block_ids,
+                "remote_block_ids": remote_block_ids,
+                "remote_engine_id": remote_engine_id,
+                "remote_request_id": remote_request_id,
+                "remote_host": remote_host,
+                "remote_handshake_port": remote_handshake_port,
+                "offset": offset,
+                "tp_num_need_pulls": tp_num_need_pulls,
+                "remote_port_send_num": remote_port_send_num,
+                "all_task_done": all_task_done,
+            }
+        )
 
     def get_and_clear_finished_requests(self) -> set[str]:
         """
@@ -435,19 +433,13 @@ class KVCacheRecvingThread(threading.Thread):
         all_task_done = req_meta["all_task_done"]
 
         try:
-            logger.debug(
-                f"Starting to transfer KV cache for request {remote_request_id}.")
+            logger.debug(f"Starting to transfer KV cache for request {remote_request_id}.")
             self._transfer_kv_cache(req_meta)
-            logger.debug(
-                f"Finished transferring KV cache for request {remote_request_id}.")
+            logger.debug(f"Finished transferring KV cache for request {remote_request_id}.")
         except Exception as e:
-            logger.error(
-                "Failed to transfer KV cache for request "
-                f"{remote_request_id}: {e}",
-                exc_info=True)
+            logger.error(f"Failed to transfer KV cache for request {remote_request_id}: {e}", exc_info=True)
         finally:
-            self._send_done_signal_to_free_remote_port(remote_request_id, remote_host,
-                                                       remote_port_send_num)
+            self._send_done_signal_to_free_remote_port(remote_request_id, remote_host, remote_port_send_num)
             if all_task_done:
                 self.task_tracker.update_done_task_count(request_id)
                 if request_id in self.proc_not_transfer_request:
@@ -456,25 +448,20 @@ class KVCacheRecvingThread(threading.Thread):
         # Always send the done signal to the remote host to ensure proper
         # resource cleanup. Failing to do so may cause a memory leak on the
         # remote host.
-        self._send_done_recv_signal(remote_request_id, remote_host,
-                                    remote_handshake_port,
-                                    remote_port_send_num)
+        self._send_done_recv_signal(remote_request_id, remote_host, remote_handshake_port, remote_port_send_num)
 
     def _send_done_signal_to_free_remote_port(
-            self, request_id: str, remote_host: str,
-            remote_port_send_num: dict[int, RemotePortInfo]):
-        if self.side_channel_port != self.local_handshake_port \
-            or not remote_port_send_num:
+        self, request_id: str, remote_host: str, remote_port_send_num: dict[int, RemotePortInfo]
+    ):
+        if self.side_channel_port != self.local_handshake_port or not remote_port_send_num:
            return
         if request_id not in self.proc_not_transfer_request:
            self.proc_not_transfer_request[request_id] = True
         if self.proc_not_transfer_request[request_id]:
-            for remote_port in remote_port_send_num.keys():
-                if remote_port_send_num[remote_port]['num'] == 0:
-                    remote_host_ = remote_port_send_num[remote_port]['host']
-                    self._send_done_recv_signal(request_id, remote_host_,
-                                                remote_port,
-                                                remote_port_send_num)
+            for remote_port in remote_port_send_num:
+                if remote_port_send_num[remote_port]["num"] == 0:
+                    remote_host_ = remote_port_send_num[remote_port]["host"]
+                    self._send_done_recv_signal(request_id, remote_host_, remote_port, remote_port_send_num)
             self.proc_not_transfer_request[request_id] = False
 
     def _transfer_kv_cache(self, req_meta: dict[str, Any]):
@@ -500,13 +487,16 @@ class KVCacheRecvingThread(threading.Thread):
             remote_block_ids = remote_block_ids[-num_local_blocks:]
 
         # Check if we have the remote metadata cached.
-        if remote_engine_id not in self.kv_caches_base_addr or \
-            remote_handshake_port not in self.kv_caches_base_addr[remote_engine_id]:
+        if (
+            remote_engine_id not in self.kv_caches_base_addr
+            or remote_handshake_port not in self.kv_caches_base_addr[remote_engine_id]
+        ):
             self._get_remote_metadata(remote_host, remote_handshake_port)
 
         if tp_num_need_pulls == 1:
-            grouped_remote_block_ids, grouped_local_block_ids = \
-                group_concurrent_contiguous(remote_block_ids, local_block_ids)
+            grouped_remote_block_ids, grouped_local_block_ids = group_concurrent_contiguous(
+                remote_block_ids, local_block_ids
+            )
         else:
             remote_block_ids = list(map(lambda x: [x], remote_block_ids))
             local_block_ids = list(map(lambda x: [x], local_block_ids))
@@ -519,53 +509,45 @@ class KVCacheRecvingThread(threading.Thread):
         prefill_pp_rank = offset // tp_num_need_pulls  # PP rank where current request resides
         inner_offset = offset % tp_num_need_pulls  # Offset within each PP stage
 
-        remote_kv_caches_base_addrs = \
-            self.kv_caches_base_addr[remote_engine_id][remote_handshake_port]
-        first_layer_index, end_layer_index = self.pp_layer_indices[
-            prefill_pp_rank]
+        remote_kv_caches_base_addrs = self.kv_caches_base_addr[remote_engine_id][remote_handshake_port]
+        first_layer_index, end_layer_index = self.pp_layer_indices[prefill_pp_rank]
         # support MTP layer kv transfer
         if self.vllm_config.speculative_config is not None:
             # all MTP layer use the same kv cache layer, so only need to transfer once
             if prefill_pp_rank == self._prefill_pp_size - 1:
                 end_layer_index = end_layer_index + 1
-        num_cache_per_layer = len(list(
-            self.kv_caches.values())[0])  # Number of KV caches per layer
-        local_kv_caches_base_addrs = \
-            self.kv_caches_base_addr[self.local_engine_id][self.local_handshake_port][first_layer_index*num_cache_per_layer : end_layer_index*num_cache_per_layer]
-        logger.debug(
-            f"transfer kv cache first_layer_index:{first_layer_index} , end_layer_index:{end_layer_index}"
-        )
-        remote_transfer_port = self.remote_te_port[remote_engine_id][
-            remote_handshake_port]
+        num_cache_per_layer = len(list(self.kv_caches.values())[0])  # Number of KV caches per layer
+        local_kv_caches_base_addrs = self.kv_caches_base_addr[self.local_engine_id][self.local_handshake_port][
+            first_layer_index * num_cache_per_layer : end_layer_index * num_cache_per_layer
+        ]
+        logger.debug(f"transfer kv cache first_layer_index:{first_layer_index} , end_layer_index:{end_layer_index}")
+        remote_transfer_port = self.remote_te_port[remote_engine_id][remote_handshake_port]
         num_blocks = len(local_block_ids)
         session_id = f"{remote_host}:{remote_transfer_port}"
 
         req_start_time = time.perf_counter()
         src_list, dst_list, length_list = [], [], []
         for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
-                zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
+            zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)
+        ):
             if self.use_mla:
-                block_len = (self.block_len[k % 2])
+                block_len = self.block_len[k % 2]
             elif self.use_sparse:
-                block_len = (self.block_len[k % 3])
+                block_len = self.block_len[k % 3]
             else:
-                block_len = (self.block_len[0])
+                block_len = self.block_len[0]
             inner_block_len = block_len // tp_num_need_pulls
-            for remote_block_id, local_block_id in zip(
-                    grouped_remote_block_ids, grouped_local_block_ids):
-                src = src_layer_base_addr + local_block_id[
-                    0] * block_len + inner_offset * inner_block_len
+            for remote_block_id, local_block_id in zip(grouped_remote_block_ids, grouped_local_block_ids):
+                src = src_layer_base_addr + local_block_id[0] * block_len + inner_offset * inner_block_len
                 dst = dst_layer_base_addr + remote_block_id[0] * inner_block_len
                 length = inner_block_len * len(local_block_id)
                 src_list.append(src)
                 dst_list.append(dst)
                 length_list.append(length)
-        ret = self.engine.batch_transfer_sync_read(session_id, src_list,
-                                                   dst_list, length_list)
+        ret = self.engine.batch_transfer_sync_read(session_id, src_list, dst_list, length_list)
         if ret < 0:
-            logger.error("Mooncake transfer failed for request %s",
-                         req_meta["remote_request_id"])
+            logger.error("Mooncake transfer failed for request %s", req_meta["remote_request_id"])
             raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
 
         req_end_time = time.perf_counter()
@@ -573,61 +555,54 @@ class KVCacheRecvingThread(threading.Thread):
         logger.info(
             "KV cache transfer for request %s took %.2f ms (%d groups,"
             " %d blocks). local_ip %s local_device_id %s remote_session_id %s",
-            remote_request_id, req_transfer_elapsed, num_transfer_groups, num_blocks,
-            get_ip(), self.tp_rank, session_id)
+            remote_request_id,
+            req_transfer_elapsed,
+            num_transfer_groups,
+            num_blocks,
+            get_ip(),
+            self.tp_rank,
+            session_id,
+        )
 
         # Determine if the current position is the offset position at the end of
         # the KV transmission.
-        is_kv_transfer_end = (
-            global_offset == tp_num_need_pulls * self._prefill_pp_size - 1)
+        is_kv_transfer_end = global_offset == tp_num_need_pulls * self._prefill_pp_size - 1
         need_cat_cache = tp_num_need_pulls > 1 and is_kv_transfer_end
         need_nz_cache = get_ascend_config().enable_kv_nz and is_kv_transfer_end
         if need_nz_cache or need_cat_cache:
-            self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls,
-                                   need_cat_cache, need_nz_cache)
+            self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls, need_cat_cache, need_nz_cache)
 
-    def reformat_kv_cache(self,
-                          block_ids: list[list[int]],
-                          tp_num_need_pulls: int,
-                          need_cat_cache: bool = False,
-                          need_nz_cache: bool = False):
+    def reformat_kv_cache(
+        self,
+        block_ids: list[list[int]],
+        tp_num_need_pulls: int,
+        need_cat_cache: bool = False,
+        need_nz_cache: bool = False,
+    ):
         # Get necessary parameters
         k_cache = list(self.kv_caches.values())[0][0]
         dtype = k_cache.dtype
         device = k_cache.device
         flat_block_ids = [item for sublist in block_ids for item in sublist]
-        block_ids_tensor = torch.tensor(flat_block_ids,
-                                        dtype=torch.int32,
-                                        device=device)
+        block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32, device=device)
         num_blocks = len(flat_block_ids)
         num_tokens = num_blocks * self.block_size
 
         # Create device tensors for copy operations
         block_table = block_ids_tensor.view(1, -1)
-        block_len_tensor = torch.tensor([num_tokens],
-                                        dtype=torch.int32,
-                                        device=device)
+        block_len_tensor = torch.tensor([num_tokens], dtype=torch.int32, device=device)
         seq_start_tensor = torch.tensor([0], dtype=torch.int32, device=device)
 
         # Initialize buffers
-        k_buffer = torch.empty(
-            (num_tokens, self.num_kv_heads, self.k_head_dim),
-            dtype=dtype,
-            device=device)
-        v_buffer = torch.empty(
-            (num_tokens, self.num_kv_heads, self.v_head_dim),
-            dtype=dtype,
-            device=device)
+        k_buffer = torch.empty((num_tokens, self.num_kv_heads, self.k_head_dim), dtype=dtype, device=device)
+        v_buffer = torch.empty((num_tokens, self.num_kv_heads, self.v_head_dim), dtype=dtype, device=device)
 
         # Create slot mapping for reshape operations
-        block_offsets = torch.arange(0,
-                                     self.block_size,
-                                     dtype=torch.int32,
-                                     device=device)
-        slot_mapping = (block_offsets.reshape(
-            (1, self.block_size)) + block_ids_tensor.reshape(
-                (num_blocks, 1)) * self.block_size).flatten()
+        block_offsets = torch.arange(0, self.block_size, dtype=torch.int32, device=device)
+        slot_mapping = (
+            block_offsets.reshape((1, self.block_size)) + block_ids_tensor.reshape((num_blocks, 1)) * self.block_size
+        ).flatten()
 
         # FIXME: Right now, if we skip synchronization at this point, the system
         # will crash in GQA scenarios. However, we still haven't identified the
@@ -637,30 +612,36 @@ class KVCacheRecvingThread(threading.Thread):
         # Process each layer in the KV cache
         for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items():
             # Load cache data into buffers
-            torch_npu.atb.npu_paged_cache_load(k_cache_layer,
-                                               v_cache_layer,
-                                               block_table,
-                                               block_len_tensor,
-                                               seq_starts=seq_start_tensor,
-                                               key=k_buffer,
-                                               value=v_buffer)
+            torch_npu.atb.npu_paged_cache_load(
+                k_cache_layer,
+                v_cache_layer,
+                block_table,
+                block_len_tensor,
+                seq_starts=seq_start_tensor,
+                key=k_buffer,
+                value=v_buffer,
+            )
             if need_cat_cache:
-                self._cat_kv_cache(k_cache_layer, v_cache_layer, k_buffer,
-                                   v_buffer, tp_num_need_pulls, num_blocks,
-                                   num_tokens, slot_mapping)
+                self._cat_kv_cache(
+                    k_cache_layer,
+                    v_cache_layer,
+                    k_buffer,
+                    v_buffer,
+                    tp_num_need_pulls,
+                    num_blocks,
+                    num_tokens,
+                    slot_mapping,
+                )
             if need_nz_cache:
-                self._nz_kv_cache(k_cache_layer, v_cache_layer, k_buffer,
-                                  v_buffer, slot_mapping)
+                self._nz_kv_cache(k_cache_layer, v_cache_layer, k_buffer, v_buffer, slot_mapping)
 
         # Clean up buffers
         del k_buffer, v_buffer
 
-    def _cat_kv_cache(self, k_cache_layer, v_cache_layer, k_buffer, v_buffer,
-                      tp_num_need_pulls, num_blocks, num_tokens, slot_mapping):
-
-        def _transpose_kv_cache_between_head(
-                buffer: torch.Tensor) -> torch.Tensor:
-            buffer = buffer.view(num_blocks, tp_num_need_pulls,
-                                 self.block_size, -1)
+    def _cat_kv_cache(
+        self, k_cache_layer, v_cache_layer, k_buffer, v_buffer, tp_num_need_pulls, num_blocks, num_tokens, slot_mapping
+    ):
+        def _transpose_kv_cache_between_head(buffer: torch.Tensor) -> torch.Tensor:
+            buffer = buffer.view(num_blocks, tp_num_need_pulls, self.block_size, -1)
             buffer.transpose_(1, 2)
             return buffer.contiguous().view(num_tokens, self.num_kv_heads, -1)
 
@@ -669,28 +650,23 @@ class KVCacheRecvingThread(threading.Thread):
         v_buffer = _transpose_kv_cache_between_head(v_buffer)
 
         # Reshape and cache the processed buffers
-        torch_npu._npu_reshape_and_cache(key=k_buffer,
-                                         value=v_buffer,
-                                         key_cache=k_cache_layer,
-                                         value_cache=v_cache_layer,
-                                         slot_indices=slot_mapping)
+        torch_npu._npu_reshape_and_cache(
+            key=k_buffer, value=v_buffer, key_cache=k_cache_layer, value_cache=v_cache_layer, slot_indices=slot_mapping
+        )
 
-    def _nz_kv_cache(self, k_cache_layer, v_cache_layer, k_buffer, v_buffer,
-                     slot_mapping):
+    def _nz_kv_cache(self, k_cache_layer, v_cache_layer, k_buffer, v_buffer, slot_mapping):
         nz_fmt_last_dim = 16
         k_cache_layer = k_cache_layer.view(
-            -1, self.k_head_dim * self.num_kv_heads // nz_fmt_last_dim,
-            self.block_size, nz_fmt_last_dim)
+            -1, self.k_head_dim * self.num_kv_heads // nz_fmt_last_dim, self.block_size, nz_fmt_last_dim
+        )
         v_cache_layer = v_cache_layer.view(
-            -1, self.v_head_dim * self.num_kv_heads // nz_fmt_last_dim,
-            self.block_size, nz_fmt_last_dim)
-        torch_npu.npu_scatter_pa_kv_cache(k_buffer, v_buffer, k_cache_layer,
-                                          v_cache_layer, slot_mapping)
+            -1, self.v_head_dim * self.num_kv_heads // nz_fmt_last_dim, self.block_size, nz_fmt_last_dim
+        )
+        torch_npu.npu_scatter_pa_kv_cache(k_buffer, v_buffer, k_cache_layer, v_cache_layer, slot_mapping)
 
-    def _get_remote_metadata(self, remote_host: str,
-                             remote_handshake_port: int) -> None:
+    def _get_remote_metadata(self, remote_host: str, remote_handshake_port: int) -> None:
         """Get the metadata from the remote host."""
-        sock: Optional[zmq.Socket] = None  # type: ignore
+        sock: zmq.Socket | None = None  # type: ignore
         try:
             sock = self._get_remote_socket(remote_host, remote_handshake_port)
             ensure_zmq_send(sock, self.encoder.encode((GET_META_MSG, "")), f"{remote_host}:{remote_handshake_port}")
@@ -698,60 +674,50 @@ class KVCacheRecvingThread(threading.Thread):
             agent_meta = self.decoder.decode(metadata_bytes)
             engine_id = agent_meta.engine_id
             assert engine_id != self.local_engine_id, (
-                f"Conflict engine id {engine_id} with local engine id "
-                f"{self.local_engine_id}.")
-            self.kv_caches_base_addr[engine_id][remote_handshake_port] = \
-                agent_meta.kv_caches_base_addr
-            self.remote_te_port[engine_id][remote_handshake_port] = \
-                agent_meta.te_rpc_port
+                f"Conflict engine id {engine_id} with local engine id {self.local_engine_id}."
+            )
+            self.kv_caches_base_addr[engine_id][remote_handshake_port] = agent_meta.kv_caches_base_addr
+            self.remote_te_port[engine_id][remote_handshake_port] = agent_meta.te_rpc_port
         finally:
             if sock is not None:
-                self._return_remote_socket(sock, remote_host,
-                                           remote_handshake_port)
-                logger.debug("Returned socket to pool for %s:%d", remote_host,
-                             remote_handshake_port)
+                self._return_remote_socket(sock, remote_host, remote_handshake_port)
+                logger.debug("Returned socket to pool for %s:%d", remote_host, remote_handshake_port)
 
     def _send_done_recv_signal(
-            self, request_id: str, remote_host: str,
-            remote_handshake_port: int,
-            remote_port_send_num: dict[int, RemotePortInfo]):
-        logger.debug("Sending done recving signal for request %s to %s:%d",
-                     request_id, remote_host, remote_handshake_port)
-        sock: Optional[zmq.Socket] = None  # type: ignore
+        self,
+        request_id: str,
+        remote_host: str,
+        remote_handshake_port: int,
+        remote_port_send_num: dict[int, RemotePortInfo],
+    ):
+        logger.debug(
+            "Sending done recving signal for request %s to %s:%d", request_id, remote_host, remote_handshake_port
+        )
+        sock: zmq.Socket | None = None  # type: ignore
         try:
             sock = self._get_remote_socket(remote_host, remote_handshake_port)
-            data_bytes = self.encoder.encode(
-                (DONE_RECVING_MSG, request_id, remote_port_send_num))
+            data_bytes = self.encoder.encode((DONE_RECVING_MSG, request_id, remote_port_send_num))
             ensure_zmq_send(sock, data_bytes, f"{remote_host}:{remote_handshake_port}")
-            resp = ensure_zmq_recv(sock,
-                                   self.remote_poller,
-                                   f"{remote_host}:{remote_handshake_port}",
-                                   timeout=self.timeout)
-            logger.debug(
-                f"Received response for request {request_id}: {resp.decode('utf-8')}"
+            resp = ensure_zmq_recv(
+                sock, self.remote_poller, f"{remote_host}:{remote_handshake_port}", timeout=self.timeout
             )
+            logger.debug(f"Received response for request {request_id}: {resp.decode('utf-8')}")
             if resp != b"ACK":
-                logger.error("Failed to receive ACK for request %s from %s:%d",
-                             request_id, remote_host, remote_handshake_port)
-                raise RuntimeError(
-                    f"Failed to receive ACK, resp: {resp.decode('utf-8')}")
+                logger.error(
+                    "Failed to receive ACK for request %s from %s:%d", request_id, remote_host, remote_handshake_port
+                )
+                raise RuntimeError(f"Failed to receive ACK, resp: {resp.decode('utf-8')}")
         except RuntimeError as e:
             if isinstance(sock, zmq.Socket):  # type: ignore
                 sock.close()
                 sock = None
-            logger.warning(
-                f"Unexpected error occurred in socket, {e}, closing the original channel"
-            )
+            logger.warning(f"Unexpected error occurred in socket, {e}, closing the original channel")
         finally:
             if sock is not None:
-                self._return_remote_socket(sock, remote_host,
-                                           remote_handshake_port)
-                logger.debug("Returned socket to pool for %s:%d", remote_host,
-                             remote_handshake_port)
+                self._return_remote_socket(sock, remote_host, remote_handshake_port)
+                logger.debug("Returned socket to pool for %s:%d", remote_host, remote_handshake_port)
 
-    def _get_remote_socket(
-            self, remote_host: str,
-            remote_handshake_port: int) -> zmq.Socket:  # type: ignore
+    def _get_remote_socket(self, remote_host: str, remote_handshake_port: int) -> zmq.Socket:  # type: ignore
         """Get a socket to the remote host."""
         remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port)
         with self.remote_sockets_lock:
@@ -763,18 +729,21 @@ class KVCacheRecvingThread(threading.Thread):
                 ctx=ctx,
                 path=remote_path,
                 socket_type=zmq.REQ,  # type: ignore
-                bind=False)
+                bind=False,
+            )
             sock.setsockopt(
                 zmq.SNDTIMEO,  # type: ignore
-                int(self.timeout * 1000))
+                int(self.timeout * 1000),
+            )
             self.remote_poller.register(sock, zmq.POLLIN)  # type: ignore
             return sock
 
     def _return_remote_socket(
-            self,
-            sock: zmq.Socket,  # type: ignore
-            remote_host: str,
-            remote_handshake_port: int) -> None:
+        self,
+        sock: zmq.Socket,  # type: ignore
+        remote_host: str,
+        remote_handshake_port: int,
+    ) -> None:
         """Return the remote socket to the pool."""
         remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port)
         with self.remote_sockets_lock:
@@ -782,7 +751,6 @@ class KVCacheRecvingThread(threading.Thread):
 
 
 class MooncakeConnectorMetadata(KVConnectorMetadata):
-
     def __init__(self):
         self.requests: dict[str, ReqMeta] = {}
         self.requests_to_send: dict[str, float] = {}
@@ -805,48 +773,37 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
             remote_port=kv_transfer_params["remote_port"],
             remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1),
             remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1),
-            remote_multi_nodes_meta_mapping=kv_transfer_params.get(
-                "remote_multi_nodes_meta_mapping", {}),
+            remote_multi_nodes_meta_mapping=kv_transfer_params.get("remote_multi_nodes_meta_mapping", {}),
            num_prompt_blocks=kv_transfer_params.get("num_prompt_blocks", 0),
         )
 
 
 class MooncakeConnector(KVConnectorBase_V1):
-
-    def __init__(self,
-                 vllm_config: VllmConfig,
-                 role: KVConnectorRole,
-                 kv_cache_config: Optional[KVCacheConfig] = None):
+    def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: KVCacheConfig | None = None):
         assert vllm_config.kv_transfer_config is not None
         self.engine_id = vllm_config.kv_transfer_config.engine_id
         self._connector_metadata = MooncakeConnectorMetadata()
 
         if role == KVConnectorRole.SCHEDULER:
-            self.connector_scheduler: Optional[MooncakeConnectorScheduler] = \
-                MooncakeConnectorScheduler(vllm_config, str(self.engine_id))
-            self.connector_worker: Optional[MooncakeConnectorWorker] = None
+            self.connector_scheduler: MooncakeConnectorScheduler | None = MooncakeConnectorScheduler(
+                vllm_config, str(self.engine_id)
+            )
+            self.connector_worker: MooncakeConnectorWorker | None = None
         elif role == KVConnectorRole.WORKER:
             self.connector_scheduler = None
-            self.connector_worker = MooncakeConnectorWorker(
-                vllm_config, str(self.engine_id))
+            self.connector_worker = MooncakeConnectorWorker(vllm_config, str(self.engine_id))
 
     ############################################################
     # Scheduler Side Methods
     ############################################################
-    def get_num_new_matched_tokens(
-            self, request: "Request",
-            num_computed_tokens: int) -> tuple[int, bool]:
+    def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: int) -> tuple[int, bool]:
         assert self.connector_scheduler is not None
-        return self.connector_scheduler.get_num_new_matched_tokens(
-            request, num_computed_tokens)
+        return self.connector_scheduler.get_num_new_matched_tokens(request, num_computed_tokens)
 
-    def update_state_after_alloc(self, request: "Request",
-                                 blocks: "KVCacheBlocks",
-                                 num_external_tokens: int):
+    def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int):
         assert self.connector_scheduler is not None
-        return self.connector_scheduler.update_state_after_alloc(
-            request, blocks, num_external_tokens)
+        return self.connector_scheduler.update_state_after_alloc(request, blocks, num_external_tokens)
 
     def build_connector_meta(
         self,
@@ -859,7 +816,7 @@ class MooncakeConnector(KVConnectorBase_V1):
         self,
         request: "Request",
         block_ids: list[int],
-    ) -> tuple[bool, Optional[dict[str, Any]]]:
+    ) -> tuple[bool, dict[str, Any] | None]:
         assert self.connector_scheduler is not None
         return self.connector_scheduler.request_finished(request, block_ids)
 
@@ -870,14 +827,12 @@ class MooncakeConnector(KVConnectorBase_V1):
         assert self.connector_worker is not None
         self.connector_worker.register_kv_caches(kv_caches)
 
-    def get_finished(self,
-                     finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
+    def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
         """Get the finished recving and sending requests."""
         assert self.connector_worker is not None
         return self.connector_worker.get_finished()
 
-    def start_load_kv(self, forward_context: "ForwardContext",
-                      **kwargs) -> None:
+    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
         assert self.connector_worker is not None
         assert isinstance(self._connector_metadata, MooncakeConnectorMetadata)
         self.connector_worker.start_load_kv(self._connector_metadata)
@@ -886,8 +841,9 @@ class MooncakeConnector(KVConnectorBase_V1):
         """MooncakeConnector does not do layerwise saving."""
         pass
 
-    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
-                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
+    def save_kv_layer(
+        self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs
+    ) -> None:
         """MooncakeConnector does not save explicitly."""
         pass
 
@@ -908,8 +864,7 @@ class MooncakeConnector(KVConnectorBase_V1):
         assert self.connector_worker is not None
         return self.connector_worker.xfer_handshake_metadata
 
-    def set_xfer_handshake_metadata(
-            self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None:
+    def set_xfer_handshake_metadata(self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None:
         """
         Set the KV connector handshake metadata for this connector.
@@ -935,17 +890,21 @@ class MooncakeConnectorScheduler:
         self.side_channel_host = get_ip()
         self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
         self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
-        self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \
-            vllm_config.parallel_config.data_parallel_size * \
-            self.pcp_size * \
-            vllm_config.parallel_config.pipeline_parallel_size
+        self.max_device_id = (
+            vllm_config.parallel_config.tensor_parallel_size
+            * vllm_config.parallel_config.data_parallel_size
+            * self.pcp_size
+            * vllm_config.parallel_config.pipeline_parallel_size
+        )
         # Handshake base port
         self.side_channel_port = (
-            vllm_config.kv_transfer_config.kv_port +
-            vllm_config.parallel_config.data_parallel_rank *
-            vllm_config.parallel_config.tensor_parallel_size *
-            vllm_config.parallel_config.pipeline_parallel_size * self.pcp_size)
+            vllm_config.kv_transfer_config.kv_port
+            + vllm_config.parallel_config.data_parallel_rank
+            * vllm_config.parallel_config.tensor_parallel_size
+            * vllm_config.parallel_config.pipeline_parallel_size
+            * self.pcp_size
+        )
         # Requests that need to start recv.
         # New requests are added by update_state_after_alloc in
         # the scheduler. Used to make metadata passed to Worker.
@@ -956,9 +915,7 @@ class MooncakeConnectorScheduler:
         # master-slave meta information for cross-nodes
         self.multi_nodes_meta_mapping: dict[str, dict[str, Any]] = {}
 
-    def get_num_new_matched_tokens(
-            self, request: "Request",
-            num_computed_tokens: int) -> tuple[int, bool]:
+    def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: int) -> tuple[int, bool]:
         """
         For remote prefill, pull all prompt blocks from remote
         asynchronously relative to engine execution.
@@ -976,9 +933,10 @@ class MooncakeConnectorScheduler:
 
         params = request.kv_transfer_params
         logger.debug(
-            "MooncakeConnector get_num_new_matched_tokens: "
-            "num_computed_tokens=%s, kv_transfer_params=%s",
-            num_computed_tokens, params)
+            "MooncakeConnector get_num_new_matched_tokens: num_computed_tokens=%s, kv_transfer_params=%s",
+            num_computed_tokens,
+            params,
+        )
 
         if params is not None and params.get("do_remote_prefill"):
             # Remote prefill: get all prompt blocks from remote.
@@ -990,32 +948,24 @@ class MooncakeConnectorScheduler:
 
         # No remote prefill for this request.
         return 0, False
 
-    def update_state_after_alloc(self, request: "Request",
-                                 blocks: "KVCacheBlocks",
-                                 num_external_tokens: int):
-
+    def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int):
         params = request.kv_transfer_params
         logger.debug(
-            "MooncakeConnector update_state_after_alloc: "
-            "num_external_tokens=%s, kv_transfer_params=%s",
-            num_external_tokens, params)
+            "MooncakeConnector update_state_after_alloc: num_external_tokens=%s, kv_transfer_params=%s",
+            num_external_tokens,
+            params,
+        )
 
-        if params is not None and (params.get("do_remote_prefill", False)
-                                   or params.get("do_remote_decode", False)):
+        if params is not None and (params.get("do_remote_prefill", False) or params.get("do_remote_decode", False)):
             self._reqs_in_batch.add(request.request_id)
 
         if params is not None and params.get("do_remote_prefill"):
             if params.get("remote_block_ids"):
-                if all(p in params for p in ("remote_engine_id", "remote_host",
-                                             "remote_port", "remote_request_id")):
-                    local_block_ids = (blocks.get_unhashed_block_ids()
-                                       if num_external_tokens > 0 else [])
+                if all(p in params for p in ("remote_engine_id", "remote_host", "remote_port", "remote_request_id")):
+                    local_block_ids = blocks.get_unhashed_block_ids() if num_external_tokens > 0 else []
                     # Get unhashed blocks to pull from remote.
-                    self._reqs_need_recv[request.request_id] = (
-                        request, local_block_ids, num_external_tokens)
+                    self._reqs_need_recv[request.request_id] = (request, local_block_ids, num_external_tokens)
                 else:
-                    logger.warning(
-                        "Got invalid KVTransferParams: %s. This "
-                        "request will not utilize KVTransfer", params)
+                    logger.warning("Got invalid KVTransferParams: %s. This request will not utilize KVTransfer", params)
             else:
                 assert num_external_tokens == 0
             # Only trigger 1 KV transfer per request.
@@ -1028,8 +978,7 @@ class MooncakeConnectorScheduler:
         meta = MooncakeConnectorMetadata()
 
         # Loop through scheduled reqs and convert to ReqMeta.
-        for req_id, (req, block_ids,
-                     num_external_tokens) in self._reqs_need_recv.items():
+        for req_id, (req, block_ids, num_external_tokens) in self._reqs_need_recv.items():
             assert req.kv_transfer_params is not None
             # For the case where there are no remote blocks to pull
             # (block_ids is empty), we don't need to schedule
@@ -1054,7 +1003,7 @@ class MooncakeConnectorScheduler:
         self,
         request: "Request",
         block_ids: list[int],
-    ) -> tuple[bool, Optional[dict[str, Any]]]:
+    ) -> tuple[bool, dict[str, Any] | None]:
         """
         Once a request is finished, determine whether request blocks
         should be freed now or will be sent asynchronously and freed later.
@@ -1062,22 +1011,23 @@ class MooncakeConnectorScheduler:
 
         params = request.kv_transfer_params
         logger.debug(
-            "MooncakeConnector request_finished, request_status=%s, "
-            "kv_transfer_params=%s", request.status, params)
+            "MooncakeConnector request_finished, request_status=%s, kv_transfer_params=%s", request.status, params
+        )
 
-        if (params is None or not params.get("do_remote_decode")
-                or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
+        if (
+            params is None
+            or not params.get("do_remote_decode")
+            or request.status != RequestStatus.FINISHED_LENGTH_CAPPED
+        ):
             return False, None
 
         computed_block_ids = block_ids
         delay_free_blocks = len(computed_block_ids) > 0
 
         if delay_free_blocks:
-            logger.info("Delaying free of %d blocks for request %s",
-                        len(computed_block_ids), request.request_id)
+            logger.info("Delaying free of %d blocks for request %s", len(computed_block_ids), request.request_id)
             self._reqs_need_send[request.request_id] = time.time()
 
-        num_prompt_blocks = math.ceil(
-            len(request.prompt_token_ids) / self.block_size)
+        num_prompt_blocks = math.ceil(len(request.prompt_token_ids) / self.block_size)
 
         return delay_free_blocks, dict(
             do_remote_prefill=True,
@@ -1094,8 +1044,7 @@ class MooncakeConnectorScheduler:
             num_prompt_blocks=num_prompt_blocks,
         )
 
-    def set_xfer_handshake_metadata(
-            self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None:
+    def set_xfer_handshake_metadata(self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None:
         """
         Set the KV connector handshake metadata for this connector.
 
@@ -1114,12 +1063,12 @@ class MooncakeConnectorWorker:
 
     def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self._get_prefill_decode_size(vllm_config)
-        os.environ["ASCEND_TRANSFER_TIMEOUT"] = str(
-            get_transfer_timeout_value())
+        os.environ["ASCEND_TRANSFER_TIMEOUT"] = str(get_transfer_timeout_value())
         if self._prefill_tp_size < self._decode_tp_size:
             raise ValueError(
                 f"prefill_tp_size: {self._prefill_tp_size} must be greater than"
-                f" or equal to the decode_tp_size: {self._decode_tp_size}")
+                f" or equal to the decode_tp_size: {self._decode_tp_size}"
+            )
 
         # Metadata.
         self.vllm_config = vllm_config
@@ -1136,13 +1085,10 @@ class MooncakeConnectorWorker:
         self.side_channel_host = get_ip()
         self.pcp_size = get_pcp_group().world_size
         # Assert that pp_size and pcp_size cannot both be greater than 1
-        assert not (self.pp_size > 1 and self.pcp_size
-                    > 1), "pp and pcp cannot open in same time"
-        self.pcp_rank = get_pcp_group(
-        ).rank_in_group if self.pcp_size > 1 else 0
+        assert not (self.pp_size > 1 and self.pcp_size > 1), "pp and pcp cannot open in same time"
+        self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
         self.dcp_size = get_decode_context_model_parallel_world_size()
-        self.dcp_rank = get_decode_context_model_parallel_rank(
-        ) if self.dcp_size > 1 else 0
+        self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0
         self.max_device_id = self.tp_size * self.dp_size * self.pcp_size * self.pp_size
         self.kv_role = vllm_config.kv_transfer_config.kv_role
@@ -1150,21 +1096,21 @@ class MooncakeConnectorWorker:
 
         # Handshake base port
         self.side_channel_port = (
-            vllm_config.kv_transfer_config.kv_port +
-            vllm_config.parallel_config.data_parallel_rank *
-            vllm_config.parallel_config.tensor_parallel_size *
-            vllm_config.parallel_config.pipeline_parallel_size * self.pcp_size)
-        device_index = (self.pp_rank +
-                        self.pcp_rank) * self.tp_size + self.tp_rank
+            vllm_config.kv_transfer_config.kv_port
+            + vllm_config.parallel_config.data_parallel_rank
+            * vllm_config.parallel_config.tensor_parallel_size
+            * vllm_config.parallel_config.pipeline_parallel_size
+            * self.pcp_size
+        )
+        device_index = (self.pp_rank + self.pcp_rank) * self.tp_size + self.tp_rank
         self.handshake_port = self.side_channel_port + device_index
         self.sockets: dict = {}
 
-        self.engine = global_te.get_transfer_engine(self.side_channel_host,
-                                                    device_name=None)
+        self.engine = global_te.get_transfer_engine(self.side_channel_host, device_name=None)
         self.te_rpc_port = self.engine.get_rpc_port()
 
         # Background thread for sending or receiving KV caches.
-        self.kv_send_thread: Optional[KVCacheSendingThread] = None
-        self.kv_recv_thread: Optional[KVCacheRecvingThread] = None
+        self.kv_send_thread: KVCacheSendingThread | None = None
+        self.kv_recv_thread: KVCacheRecvingThread | None = None
 
         # Handshake metadata of this worker
         self.xfer_handshake_metadata: MooncakeAgentMetadata | None = None
@@ -1175,41 +1121,33 @@ class MooncakeConnectorWorker:
         if self.vllm_config.model_config.is_deepseek_mla:
             self.tp_num_need_pulls = 1
         else:
-            num_d_block_heads = max(1,
-                                    self.num_key_value_heads // self.tp_size)
-            num_p_block_heads = max(
-                1, self.num_key_value_heads // self._prefill_tp_size)
+            num_d_block_heads = max(1, self.num_key_value_heads // self.tp_size)
+            num_p_block_heads = max(1, self.num_key_value_heads // self._prefill_tp_size)
             self.tp_num_need_pulls = num_d_block_heads // num_p_block_heads
 
-        self.local_remote_block_port_mapping: dict[
-            str, Optional[List[List[int]]]] = {}
+        self.local_remote_block_port_mapping: dict[str, list[list[int]] | None] = {}
         self.remote_port_send_num: dict[str, dict[int, RemotePortInfo]] = {}
 
     def _get_prefill_decode_size(self, vllm_config: VllmConfig):
         # get prefill tp and dp size from extra config
-        prefill_parallel_config: dict[
-            str, Any] = vllm_config.kv_transfer_config.get_from_extra_config(
-                "prefill", {})
+        prefill_parallel_config: dict[str, Any] = vllm_config.kv_transfer_config.get_from_extra_config("prefill", {})
 
-        assert "tp_size" in prefill_parallel_config.keys()
+        assert "tp_size" in prefill_parallel_config
         self._prefill_tp_size = prefill_parallel_config["tp_size"]
-        assert "dp_size" in prefill_parallel_config.keys()
+        assert "dp_size" in prefill_parallel_config
         self._prefill_dp_size = prefill_parallel_config["dp_size"]
         # get prefill pp size from extra config
         self._prefill_pp_size = prefill_parallel_config.get("pp_size", 1)
 
         # get decode tp and dp size from extra config
-        decode_parallel_config: dict[
-            str, Any] = vllm_config.kv_transfer_config.get_from_extra_config(
-                "decode", {})
-        assert "tp_size" in decode_parallel_config.keys()
+        decode_parallel_config: dict[str, Any] = vllm_config.kv_transfer_config.get_from_extra_config("decode", {})
+        assert "tp_size" in decode_parallel_config
         self._decode_tp_size = decode_parallel_config["tp_size"]
-        assert "dp_size" in decode_parallel_config.keys()
+        assert "dp_size" in decode_parallel_config
         self._decode_dp_size = decode_parallel_config["dp_size"]
         # get prefill pp size from extra config
         self._decode_pp_size = decode_parallel_config.get("pp_size", 1)
         assert self._decode_pp_size == 1, "decode pp size must be 1"
-        self._prefill_pp_layer_partition = prefill_parallel_config.get(
-            "pp_layer_partition", None)
+        self._prefill_pp_layer_partition = prefill_parallel_config.get("pp_layer_partition")
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data."""
@@ -1218,9 +1156,9 @@ class MooncakeConnectorWorker:
         first_kv_cache = first_kv_cache_tuple[0]
 
         # TODO(tms): Find a more robust way to detect and handle MLA
-        self.use_mla = first_kv_cache_tuple[0].size(
-            -1) != first_kv_cache_tuple[1].size(-1) and len(
-                first_kv_cache_tuple) == 2
+        self.use_mla = (
+            first_kv_cache_tuple[0].size(-1) != first_kv_cache_tuple[1].size(-1) and len(first_kv_cache_tuple) == 2
+        )
         self.use_sparse = len(first_kv_cache_tuple) == 3
         if self.use_mla:
             # MLA case.[num_block, block_size, 1, hidden_dim]
@@ -1230,11 +1168,14 @@ class MooncakeConnectorWorker:
             block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
             self.block_len = [
                 first_kv_cache[0].element_size() * math.prod(block_shape_norm),
-                first_kv_cache[1].element_size() * math.prod(block_shape_pe)
+                first_kv_cache[1].element_size() * math.prod(block_shape_pe),
             ]
             logger.info(
                 "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
-                self.num_blocks, block_shape_norm, block_shape_pe)
+                self.num_blocks,
+                block_shape_norm,
+                block_shape_pe,
+            )
         elif self.use_sparse:
             self.num_blocks = first_kv_cache.shape[0]
             block_rank = 3  # [block_size, latent_dim]
@@ -1244,27 +1185,32 @@ class MooncakeConnectorWorker:
             self.block_len = [
                 first_kv_cache[0].element_size() * math.prod(block_shape_norm),
                 first_kv_cache[1].element_size() * math.prod(block_shape_pe),
-                first_kv_cache[2].element_size() * math.prod(block_shape_k)
+                first_kv_cache[2].element_size() * math.prod(block_shape_k),
             ]
             logger.info(
                 "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
-                self.num_blocks, block_shape_norm, block_shape_pe,
-                block_shape_k)
+                self.num_blocks,
+                block_shape_norm,
+                block_shape_pe,
+                block_shape_k,
+            )
         else:
             # eager:[num_block, block_size, num_head, hidden_dim]
             # torchair:[num_block, block_size, num_head*hidden_dim]
             self.num_blocks = first_kv_cache.shape[0]
             kv_elem_size = first_kv_cache.element_size()
-            block_rank = len(
-                first_kv_cache.shape
-            ) - 1  # [block_size, kv_heads, head_dim] or [block_size, kv_heads*head_dim]
+            block_rank = (
+                len(first_kv_cache.shape) - 1
+            )  # [block_size, kv_heads, head_dim] or [block_size, kv_heads*head_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
             self.block_len = [kv_elem_size * math.prod(block_shape)]
-            logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
-                        block_shape)
+            logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape)
 
         logger.info(
             "Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
-            self.use_mla, self.use_sparse, first_kv_cache.shape)
+            self.use_mla,
+            self.use_sparse,
+            first_kv_cache.shape,
+        )
 
         self.kv_caches = kv_caches
         kv_caches_base_addr = []
@@ -1287,9 +1233,7 @@ class MooncakeConnectorWorker:
                     ptrs.append(base_addr)
                     lengths.append(region_len)
             else:
-                cache_list = [
-                    cache_or_caches
-                ] if self.use_mla or self.use_sparse else cache_or_caches
+                cache_list = [cache_or_caches] if self.use_mla or self.use_sparse else cache_or_caches
                 for cache in cache_list:
                     base_addr = cache.data_ptr()
                     region_len = self.num_blocks * self.block_len[0]
@@ -1308,46 +1252,67 @@ class MooncakeConnectorWorker:
         self.xfer_handshake_metadata = metadata
 
         ready_event = threading.Event()
-        if self.kv_role == 'kv_producer':
+        if self.kv_role == "kv_producer":
             self.kv_send_thread = KVCacheSendingThread(
-                self.vllm_config, self.tp_rank, self._prefill_tp_size,
-                self.engine_id, self.side_channel_host, self.side_channel_port,
-                metadata, ready_event, self.kv_caches, self.pcp_rank)
+                self.vllm_config,
+                self.tp_rank,
+                self._prefill_tp_size,
+                self.engine_id,
+                self.side_channel_host,
+                self.side_channel_port,
+                metadata,
+                ready_event,
+                self.kv_caches,
+                self.pcp_rank,
+            )
             self.kv_send_thread.start()
         else:
             self.kv_recv_thread = KVCacheRecvingThread(
-                self.tp_rank, self.tp_size, self._prefill_pp_size, self.engine,
-                self.engine_id, self.handshake_port, self.side_channel_port,
-                kv_caches_base_addr, self.block_len, ready_event,
-                self.vllm_config, self.kv_caches,
-                self._prefill_pp_layer_partition)
+                self.tp_rank,
+                self.tp_size,
+                self._prefill_pp_size,
+                self.engine,
+                self.engine_id,
+                self.handshake_port,
+                self.side_channel_port,
+                kv_caches_base_addr,
+                self.block_len,
+                ready_event,
+                self.vllm_config,
+                self.kv_caches,
+                self._prefill_pp_layer_partition,
+            )
             self.kv_recv_thread.start()
 
         start_wait_time = time.time()
-        thread = self.kv_send_thread if self.kv_role == 'kv_producer' else self.kv_recv_thread
+        thread = self.kv_send_thread if self.kv_role == "kv_producer" else self.kv_recv_thread
         assert thread is not None
         while not ready_event.is_set():
             if not thread.is_alive():
-                raise RuntimeError(
-                    "KV Cache sending/receiving thread failed to start.")
+                raise RuntimeError("KV Cache sending/receiving thread failed to start.")
             if time.time() - start_wait_time > 5 * 60:
-                raise RuntimeError(
-                    "Timeout waiting for KV Cache thread to be ready.")
+                raise RuntimeError("Timeout waiting for KV Cache thread to be ready.")
             time.sleep(3)
 
     def get_finished(self) -> tuple[set[str], set[str]]:
         done_sending = (
-            self.kv_send_thread.
-            get_and_clear_finished_requests(  # type: ignore[union-attr]
-            ) if self.kv_role == 'kv_producer' else set())
+            self.kv_send_thread.get_and_clear_finished_requests(  # type: ignore[union-attr]
+            )
+            if self.kv_role == "kv_producer"
+            else set()
+        )
         done_recving = (
-            self.kv_recv_thread.
-            get_and_clear_finished_requests(  # type: ignore[union-attr]
-            ) if self.kv_role == 'kv_consumer' else set())
+            self.kv_recv_thread.get_and_clear_finished_requests(  # type: ignore[union-attr]
+            )
+            if self.kv_role == "kv_consumer"
+            else set()
+        )
         if self.tp_rank == 0:
             logger.debug(
-                "Number of completed KV cache send requests: %d, receive "
-                "requests: %d", len(done_sending), len(done_recving))
+                "Number of completed KV cache send requests: %d, receive requests: %d",
+                len(done_sending),
+                len(done_recving),
+            )
         return done_sending, done_recving
 
     def _get_kv_split_metadata(
@@ -1361,22 +1326,15 @@ class MooncakeConnectorWorker:
         """
         if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1:
             choosen_rank_list = self._get_remote_rank(req_id)
-            remote_handshake_port_list = [[
-                x + meta.remote_port for x in choosen_rank_list
-            ]]
-            local_block_ids_list, remote_block_ids_list = [
-                meta.local_block_ids
-            ], [meta.remote_block_ids]
+            remote_handshake_port_list = [[x + meta.remote_port for x in choosen_rank_list]]
+            local_block_ids_list, remote_block_ids_list = [meta.local_block_ids], [meta.remote_block_ids]
             return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
 
         def context_parallel_parameters_check():
-            assert (meta.remote_pcp_size * meta.remote_dcp_size) % (
-                self.pcp_size * self.dcp_size) == 0
+            assert (meta.remote_pcp_size * meta.remote_dcp_size) % (self.pcp_size * self.dcp_size) == 0
             if not self.use_mla:
-                p_node_heads_per_rank = math.ceil(self.num_key_value_heads /
-                                                  self._prefill_tp_size)
-                d_node_heads_per_rank = math.ceil(self.num_key_value_heads /
-                                                  self.tp_size)
+                p_node_heads_per_rank = math.ceil(self.num_key_value_heads / self._prefill_tp_size)
+                d_node_heads_per_rank = math.ceil(self.num_key_value_heads / self.tp_size)
                 assert d_node_heads_per_rank % p_node_heads_per_rank == 0
 
         def get_kv_head_groups(tp_size):
@@ -1388,8 +1346,10 @@ class MooncakeConnectorWorker:
             if self.num_key_value_heads // tp_size >= 1:
                 kv_head_groups = []
                 for tp_rank in range(tp_size):
-                    kv_head_ids = [head_idx + tp_rank * (self.num_key_value_heads // tp_size) \
-                        for head_idx in range(self.num_key_value_heads // tp_size)]
+                    kv_head_ids = [
+                        head_idx + tp_rank * (self.num_key_value_heads // tp_size)
+                        for head_idx in range(self.num_key_value_heads // tp_size)
+                    ]
                     kv_head_groups.append(tuple(kv_head_ids))
                 return kv_head_groups
             if tp_size // self.num_key_value_heads > 1:
@@ -1407,10 +1367,9 @@ class MooncakeConnectorWorker:
             for kv_head_group_idx, kv_head_group in enumerate(kv_head_groups):
                 if kv_head_group not in cp_group_meta:
                     cp_group_meta[kv_head_group] = {}
-                    cp_group_meta[kv_head_group]['cp_groups'] = []
-                    cp_group_meta[kv_head_group]['select_cp_groups_id'] = 0
-                kv_head_group_offset = tp_size // len(
-                    kv_head_groups) * kv_head_group_idx
+                    cp_group_meta[kv_head_group]["cp_groups"] = []
+                    cp_group_meta[kv_head_group]["select_cp_groups_id"] = 0
+                kv_head_group_offset = tp_size // len(kv_head_groups) * kv_head_group_idx
                 for dcp_repeat_idx in range(dcp_repeat_num):
                     # len(cp_group) == pcp_size * dcp_size
                     cp_group = []
@@ -1418,38 +1377,33 @@ class MooncakeConnectorWorker:
                     for pcp_rank in range(pcp_size):
                         pcp_rank_offset = tp_size * pcp_rank
                         for dcp_rank in range(dcp_size):
-                            cp_group.append(dcp_rank + port_base +
-                                            pcp_rank_offset +
-                                            dcp_repeat_offset +
-                                            kv_head_group_offset)
-                    cp_group_meta[kv_head_group]['cp_groups'].append(cp_group)
+                            cp_group.append(
+                                dcp_rank + port_base + pcp_rank_offset + dcp_repeat_offset + kv_head_group_offset
+                            )
cp_group_meta[kv_head_group]["cp_groups"].append(cp_group) return cp_group_meta def get_local_remote_block_port_mappings(): context_parallel_parameters_check() - p_node_cp_group_meta = get_cp_group_meta(self._prefill_tp_size, - meta.remote_pcp_size, - meta.remote_dcp_size, - meta.remote_port) - d_node_cp_group_meta = get_cp_group_meta(self.tp_size, - self.pcp_size, - self.dcp_size, - self.side_channel_port) + p_node_cp_group_meta = get_cp_group_meta( + self._prefill_tp_size, meta.remote_pcp_size, meta.remote_dcp_size, meta.remote_port + ) + d_node_cp_group_meta = get_cp_group_meta(self.tp_size, self.pcp_size, self.dcp_size, self.side_channel_port) local_remote_block_port_mappings: dict[int, list[list[int]]] = {} - for d_node_head_key in d_node_cp_group_meta.keys(): - for p_node_head_key in p_node_cp_group_meta.keys(): + for d_node_head_key in d_node_cp_group_meta: + for p_node_head_key in p_node_cp_group_meta: if not set(p_node_head_key).issubset(set(d_node_head_key)): continue d_node_head_group = d_node_cp_group_meta[d_node_head_key] p_node_head_group = p_node_cp_group_meta[p_node_head_key] - for d_cp_group in d_node_head_group['cp_groups']: - select_cp_groups_id = p_node_head_group[ - 'select_cp_groups_id'] - p_cp_groups = p_node_head_group['cp_groups'] + for d_cp_group in d_node_head_group["cp_groups"]: + select_cp_groups_id = p_node_head_group["select_cp_groups_id"] + p_cp_groups = p_node_head_group["cp_groups"] p_cp_group = p_cp_groups[select_cp_groups_id] - p_node_head_group['select_cp_groups_id'] = select_cp_groups_id + 1 \ - if select_cp_groups_id + 1 < len(p_cp_groups) else 0 + p_node_head_group["select_cp_groups_id"] = ( + select_cp_groups_id + 1 if select_cp_groups_id + 1 < len(p_cp_groups) else 0 + ) for d_idx, d_port in enumerate(d_cp_group): if d_port not in local_remote_block_port_mappings: local_remote_block_port_mappings[d_port] = [] @@ -1457,19 +1411,20 @@ class MooncakeConnectorWorker: for p_idx, p_port in enumerate(p_cp_group): if p_idx % len(d_cp_group) == d_idx: p_port_remote_list.append(p_port) - local_remote_block_port_mappings[d_port].append( - p_port_remote_list) + local_remote_block_port_mappings[d_port].append(p_port_remote_list) logger.info( "p_node_cp_group_meta is:: %s. d_node_cp_group_meta is:: %s. " "local_remote_block_port_mappings is:: %s. 
", - p_node_cp_group_meta, d_node_cp_group_meta, - local_remote_block_port_mappings) + p_node_cp_group_meta, + d_node_cp_group_meta, + local_remote_block_port_mappings, + ) return local_remote_block_port_mappings def get_remote_port_send_num( - local_remote_block_port_mappings: dict[int, list[list[int]]] + local_remote_block_port_mappings: dict[int, list[list[int]]], ) -> dict[int, RemotePortInfo]: remote_port_send_num: dict[int, RemotePortInfo] = {} for port in range(self._prefill_tp_size * meta.remote_pcp_size): @@ -1477,59 +1432,49 @@ class MooncakeConnectorWorker: if remote_host_info is None: remote_host = meta.remote_host else: - remote_host = remote_host_info['host'] - remote_port_send_num[meta.remote_port + port] = { - 'num': 0, - 'host': remote_host - } + remote_host = remote_host_info["host"] + remote_port_send_num[meta.remote_port + port] = {"num": 0, "host": remote_host} - for remote_port_head_list in local_remote_block_port_mappings.values( - ): + for remote_port_head_list in local_remote_block_port_mappings.values(): for remote_port_list in remote_port_head_list: for remote_port in remote_port_list: - remote_port_send_num[remote_port]['num'] += 1 + remote_port_send_num[remote_port]["num"] += 1 return remote_port_send_num if meta.remote_engine_id not in self.local_remote_block_port_mapping: self.local_remote_block_port_mapping[meta.remote_engine_id] = None - + if self.local_remote_block_port_mapping[meta.remote_engine_id] is None: - local_remote_block_port_mappings = get_local_remote_block_port_mappings( + local_remote_block_port_mappings = get_local_remote_block_port_mappings() + self.local_remote_block_port_mapping[meta.remote_engine_id] = local_remote_block_port_mappings[ + self.handshake_port + ] + self.remote_port_send_num[meta.remote_engine_id] = get_remote_port_send_num( + local_remote_block_port_mappings ) - self.local_remote_block_port_mapping[ - meta.remote_engine_id] = local_remote_block_port_mappings[ - self.handshake_port] - self.remote_port_send_num[ - meta.remote_engine_id] = get_remote_port_send_num( - local_remote_block_port_mappings) - local_remote_block_port_mapping = copy.deepcopy( - self.local_remote_block_port_mapping[meta.remote_engine_id]) + local_remote_block_port_mapping = copy.deepcopy(self.local_remote_block_port_mapping[meta.remote_engine_id]) - num_external_blocks = math.ceil(meta.num_external_tokens / - self.block_size) + num_external_blocks = math.ceil(meta.num_external_tokens / self.block_size) - assert math.ceil(num_external_blocks / (self.pcp_size * self.dcp_size)) == len(meta.local_block_ids), \ - f"num_external_blocks({num_external_blocks}), cp_size({self.pcp_size * self.dcp_size}), " \ - f"local_block_ids_len ({len(meta.local_block_ids)})" - assert meta.num_prompt_blocks >= num_external_blocks, \ - f"meta.num_prompt_blocks({meta.num_prompt_blocks}), num_external_blocks({num_external_blocks})" + assert math.ceil(num_external_blocks / (self.pcp_size * self.dcp_size)) == len(meta.local_block_ids), ( + f"num_external_blocks({num_external_blocks}), cp_size({self.pcp_size * self.dcp_size}), " + f"local_block_ids_len ({len(meta.local_block_ids)})" + ) + assert meta.num_prompt_blocks >= num_external_blocks, ( + f"meta.num_prompt_blocks({meta.num_prompt_blocks}), num_external_blocks({num_external_blocks})" + ) remote_cp_size = meta.remote_pcp_size * meta.remote_dcp_size - remote_block_nums_all = [meta.num_prompt_blocks // remote_cp_size - ] * remote_cp_size + remote_block_nums_all = [meta.num_prompt_blocks // remote_cp_size] * remote_cp_size 
num_remain_blocks = meta.num_prompt_blocks % remote_cp_size for i in range(num_remain_blocks): remote_block_nums_all[i] += 1 - last_block_location = (num_remain_blocks + remote_cp_size - - 1) % remote_cp_size + last_block_location = (num_remain_blocks + remote_cp_size - 1) % remote_cp_size # Considering prefix cache, the remote_block_nums_all should be revised num_prefix_cached_blocks = meta.num_prompt_blocks - num_external_blocks - remote_block_nums_all = [ - num - num_prefix_cached_blocks // remote_cp_size - for num in remote_block_nums_all - ] + remote_block_nums_all = [num - num_prefix_cached_blocks // remote_cp_size for num in remote_block_nums_all] num_remain_blocks = num_prefix_cached_blocks % remote_cp_size for i in range(num_remain_blocks): remote_block_nums_all[i] -= 1 @@ -1567,15 +1512,15 @@ class MooncakeConnectorWorker: local_block_offset = 0 for remote_kv_id in range(len(remote_handshake_port_list)): num_blocks_to_pull = remote_block_nums[remote_kv_id] - remote_block_ids_list.append( - meta.remote_block_ids[:num_blocks_to_pull]) + remote_block_ids_list.append(meta.remote_block_ids[:num_blocks_to_pull]) local_block_ids_list.append( - meta.local_block_ids[local_block_offset:local_block_offset + - num_blocks_to_pull]) + meta.local_block_ids[local_block_offset : local_block_offset + num_blocks_to_pull] + ) local_block_offset += num_blocks_to_pull - assert self.tp_num_need_pulls == len(remote_handshake_port_list[0]), \ - f"tp_num_need_pulls: {self.tp_num_need_pulls}, remote_handshake_port_list: {remote_handshake_port_list}" + assert self.tp_num_need_pulls == len(remote_handshake_port_list[0]), ( + f"tp_num_need_pulls: {self.tp_num_need_pulls}, remote_handshake_port_list: {remote_handshake_port_list}" + ) return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list @@ -1584,12 +1529,16 @@ class MooncakeConnectorWorker: for req_id, meta in metadata.requests.items(): logger.debug( "start_load_kv for request %s from remote engine %s. " - "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, - meta.remote_engine_id, len(meta.local_block_ids), - len(meta.remote_block_ids)) + "Num local_block_ids: %s. Num remote_block_ids: %s. 
", + req_id, + meta.remote_engine_id, + len(meta.local_block_ids), + len(meta.remote_block_ids), + ) if meta.remote_pcp_size * meta.remote_dcp_size > 1: remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata( - req_id, meta) + req_id, meta + ) for pcp_dcp_rank in range(len(remote_handshake_port_list)): for i in range(self.tp_num_need_pulls): @@ -1597,36 +1546,37 @@ class MooncakeConnectorWorker: remote_host, remote_engine_id = self._get_remote_host_info_by_port( meta.remote_port, remote_handshake_port_list[pcp_dcp_rank][i], - meta.remote_host, meta.remote_engine_id, - meta.remote_multi_nodes_meta_mapping) + meta.remote_host, + meta.remote_engine_id, + meta.remote_multi_nodes_meta_mapping, + ) self.kv_recv_thread.add_request( request_id=req_id, remote_request_id=meta.remote_request_id, local_block_ids=local_block_ids_list[pcp_dcp_rank], - remote_block_ids=remote_block_ids_list[ - pcp_dcp_rank], + remote_block_ids=remote_block_ids_list[pcp_dcp_rank], remote_engine_id=remote_engine_id, remote_host=remote_host, - remote_handshake_port=remote_handshake_port_list[ - pcp_dcp_rank][i], + remote_handshake_port=remote_handshake_port_list[pcp_dcp_rank][i], offset=i, tp_num_need_pulls=self.tp_num_need_pulls, - remote_port_send_num=self.remote_port_send_num[ - meta.remote_engine_id], + remote_port_send_num=self.remote_port_send_num[meta.remote_engine_id], all_task_done=( - pcp_dcp_rank - == len(remote_handshake_port_list) - 1 - and i == self.tp_num_need_pulls - 1)) - else: #TODO: support prefill context parallel and pipeline parallel open at the same time + pcp_dcp_rank == len(remote_handshake_port_list) - 1 and i == self.tp_num_need_pulls - 1 + ), + ) + else: # TODO: support prefill context parallel and pipeline parallel open at the same time choosen_rank_list = self._get_remote_rank(req_id) - remote_handshake_port_list = [[x + meta.remote_port] - for x in choosen_rank_list] + remote_handshake_port_list = [[x + meta.remote_port] for x in choosen_rank_list] for i in range(self.tp_num_need_pulls * self._prefill_pp_size): assert self.kv_recv_thread is not None remote_host, remote_engine_id = self._get_remote_host_info_by_port( - meta.remote_port, remote_handshake_port_list[i][0], - meta.remote_host, meta.remote_engine_id, - meta.remote_multi_nodes_meta_mapping) + meta.remote_port, + remote_handshake_port_list[i][0], + meta.remote_host, + meta.remote_engine_id, + meta.remote_multi_nodes_meta_mapping, + ) self.kv_recv_thread.add_request( request_id=req_id, remote_request_id=meta.remote_request_id, @@ -1637,8 +1587,8 @@ class MooncakeConnectorWorker: remote_handshake_port=remote_handshake_port_list[i][0], offset=i, tp_num_need_pulls=self.tp_num_need_pulls, - all_task_done=(i == self.tp_num_need_pulls * - self._prefill_pp_size - 1)) + all_task_done=(i == self.tp_num_need_pulls * self._prefill_pp_size - 1), + ) for req_id in metadata.reqs_in_batch: if self.kv_send_thread is not None: @@ -1649,65 +1599,68 @@ class MooncakeConnectorWorker: if self.kv_send_thread is not None and self.pcp_size * self.dcp_size == 1: for req_id, delay_start_time in metadata.requests_to_send.items(): if self.tp_rank in self._prefill_get_remote_rank(req_id): - self.kv_send_thread.add_delayed_request( - req_id, delay_start_time) + self.kv_send_thread.add_delayed_request(req_id, delay_start_time) else: self.kv_send_thread.add_not_transfer_request(req_id) if self.kv_send_thread is not None and self.pcp_size * self.dcp_size > 1: for req_id, delay_start_time in 
metadata.requests_to_send.items(): - self.kv_send_thread.add_delayed_request( - req_id, delay_start_time) + self.kv_send_thread.add_delayed_request(req_id, delay_start_time) - def _get_remote_host_info_by_port(self, base_port: int, - remote_handshake_port: int, - remote_host: str, remote_engine_id: str, - remote_multi_nodes_meta_mapping: dict): + def _get_remote_host_info_by_port( + self, + base_port: int, + remote_handshake_port: int, + remote_host: str, + remote_engine_id: str, + remote_multi_nodes_meta_mapping: dict, + ): rank = str(remote_handshake_port - base_port) - if remote_multi_nodes_meta_mapping is None or remote_multi_nodes_meta_mapping.get( - rank, None) is None: + if remote_multi_nodes_meta_mapping is None or remote_multi_nodes_meta_mapping.get(rank) is None: return remote_host, remote_engine_id info = remote_multi_nodes_meta_mapping[rank] - return info.get("host", remote_host), info.get("engine_id", - remote_engine_id) + return info.get("host", remote_host), info.get("engine_id", remote_engine_id) - def _prefill_get_remote_rank(self, req_id: str) -> List[int]: + def _prefill_get_remote_rank(self, req_id: str) -> list[int]: return sum(self._get_remote_ranks_for_req(req_id), []) - def _get_remote_rank(self, req_id: str) -> List[int]: + def _get_remote_rank(self, req_id: str) -> list[int]: return self._get_remote_ranks_for_req(req_id)[self.tp_rank] - def _get_remote_tp_ranks(self, tp_ori_data: np.ndarray, - rand_group_index: list[int], - num_groups: int) -> List[List[int]]: + def _get_remote_tp_ranks( + self, tp_ori_data: np.ndarray, rand_group_index: list[int], num_groups: int + ) -> list[list[int]]: # random split prefill tp list tp_sampled_nums = [] - if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.use_sparse: + if ( + self._prefill_tp_size > self.num_key_value_heads + or self.vllm_config.model_config.is_deepseek_mla + or self.use_sparse + ): tp_ori_data = tp_ori_data.reshape(-1, num_groups) choosen_group = tp_ori_data[:, [rand_group_index]] flattened = choosen_group.reshape(-1).tolist() tp_sampled_nums = [ - flattened[i:i + self.tp_num_need_pulls] - for i in range(0, len(flattened), self.tp_num_need_pulls) + flattened[i : i + self.tp_num_need_pulls] for i in range(0, len(flattened), self.tp_num_need_pulls) ] # non-random split else: group_size = self._prefill_tp_size // self._decode_tp_size for i in range(self._decode_tp_size): - slice = tp_ori_data[i * group_size:(i + 1) * group_size] + slice = tp_ori_data[i * group_size : (i + 1) * group_size] tp_sampled_nums.append(slice.tolist()) return tp_sampled_nums - def _get_remote_ranks_for_req(self, req_id: str) -> List[List[int]]: + def _get_remote_ranks_for_req(self, req_id: str) -> list[list[int]]: # Divide the ports according to the TP within the PP sampled_nums = [] if self._prefill_tp_size == self._decode_tp_size: sampled_nums = list( map( - lambda tp: [ - tp + pp * self._prefill_tp_size - for pp in range(self._prefill_pp_size) - ], range(self._prefill_tp_size))) + lambda tp: [tp + pp * self._prefill_tp_size for pp in range(self._prefill_pp_size)], + range(self._prefill_tp_size), + ) + ) return sampled_nums # use deepseek mla, num_key_value_heads == 128, but consider as 1 if self.vllm_config.model_config.is_deepseek_mla or self.use_sparse: @@ -1720,14 +1673,13 @@ class MooncakeConnectorWorker: # random split prefill tp list ori_data = ori_data.reshape(self._prefill_pp_size, -1) num_groups = max( - 1, - len(ori_data[0]) // num_kv_head + 1, len(ori_data[0]) // 
num_kv_head ) # The number of redundant copies for each KV head within the PP stage - rand_group_index = rand.sample(range(num_groups), \ - (max(self._decode_tp_size // num_kv_head, 1))) # random choose a group + rand_group_index = rand.sample( + range(num_groups), max(self._decode_tp_size // num_kv_head, 1) + ) # randomly choose a group all_results = [ - self._get_remote_tp_ranks(ori_data[pp_index], rand_group_index, - num_groups) + self._get_remote_tp_ranks(ori_data[pp_index], rand_group_index, num_groups) for pp_index in range(self._prefill_pp_size) ] for group_index in range(len(all_results[0])): @@ -1739,28 +1691,24 @@ class MooncakeConnectorWorker: @contextlib.contextmanager -def zmq_ctx(socket_type: Any, - addr: str) -> Iterator[zmq.Socket]: # type: ignore +def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: # type: ignore """Context manager for a ZMQ socket""" if socket_type not in (zmq.ROUTER, zmq.REQ, zmq.DEALER): # type: ignore raise ValueError(f"Unexpected socket type: {socket_type}") - ctx: Optional[zmq.Context] = None # type: ignore + ctx: zmq.Context | None = None # type: ignore try: ctx = zmq.Context() # type: ignore - yield make_zmq_socket(ctx=ctx, - path=addr, - socket_type=socket_type, - bind=socket_type == zmq.ROUTER) # type: ignore + yield make_zmq_socket(ctx=ctx, path=addr, socket_type=socket_type, bind=socket_type == zmq.ROUTER) # type: ignore finally: if ctx is not None: ctx.destroy(linger=0) def group_concurrent_contiguous( - src: List[int], dst: List[int] -) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: + src: list[int], dst: list[int] +) -> tuple[list[npt.NDArray[np.int64]], list[npt.NDArray[np.int64]]]: """Vectorised NumPy implementation.""" src_indices: npt.NDArray[np.int64] = np.array(src, dtype=np.int64) dst_indices: npt.NDArray[np.int64] = np.array(dst, dtype=np.int64) @@ -1768,8 +1716,7 @@ def group_concurrent_contiguous( if src_indices.size == 0: return [], [] - brk = np.where((np.diff(src_indices) != 1) - | (np.diff(dst_indices) != 1))[0] + 1 + brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1 src_groups = np.split(src_indices, brk) dst_groups = np.split(dst_indices, brk) @@ -1790,10 +1737,11 @@ def string_to_int64_hash(input_str): def ensure_zmq_send( - socket: zmq.Socket, # type: ignore - data: bytes, - path: str, - max_retries: int = 3): + socket: zmq.Socket, # type: ignore + data: bytes, + path: str, + max_retries: int = 3, +): retries_left = max_retries while True: try: @@ -1802,22 +1750,20 @@ def ensure_zmq_send( except zmq.ZMQError as e: # type: ignore retries_left -= 1 if retries_left > 0: - logger.warning( - f"Send failed: {e}, retrying... ({retries_left} " - "attempts left)") + logger.warning(f"Send failed: {e}, retrying... 
({retries_left} attempts left)") time.sleep(0.1) else: logger.error(f"Send failed after all retries: {e}") - raise RuntimeError(f"Failed to send data to {path} after {max_retries} " - f"retries: {e}") + raise RuntimeError(f"Failed to send data to {path} after {max_retries} retries: {e}") def ensure_zmq_recv( - socket: zmq.Socket, # type: ignore - poller: zmq.Poller, # type: ignore - path: str, - timeout: float = 1.0, - max_retries: int = 3) -> bytes: + socket: zmq.Socket, # type: ignore + poller: zmq.Poller, # type: ignore + path: str, + timeout: float = 1.0, + max_retries: int = 3, +) -> bytes: retries_left = max_retries while True: try: @@ -1829,39 +1775,30 @@ def ensure_zmq_recv( except zmq.ZMQError as e: # type: ignore retries_left -= 1 if retries_left > 0: - logger.warning(f"Receive failed: {e}, retrying... " - f"({retries_left} attempts left)") + logger.warning(f"Receive failed: {e}, retrying... ({retries_left} attempts left)") time.sleep(0.1) else: logger.error(f"Receive failed from {path} after all retries: {e}") - raise RuntimeError( - f"Failed to receive data after {max_retries} " - f"retries: {e}") + raise RuntimeError(f"Failed to receive data after {max_retries} retries: {e}") # decode node should know pp_partition_layer in prefill node, # it is configured in kv_transfer_config by partition_list_str, # default using vllm layer split algorithm. def get_prefill_pp_indices( - num_hidden_layers: int, - pp_rank: int, - pp_size: int, - partition_list_str: Optional[str] = None) -> tuple[int, int]: + num_hidden_layers: int, pp_rank: int, pp_size: int, partition_list_str: str | None = None +) -> tuple[int, int]: if partition_list_str is None: return get_pp_indices(num_hidden_layers, pp_rank, pp_size) else: try: - partitions = [ - int(layer) for layer in partition_list_str.split(",") - ] + partitions = [int(layer) for layer in partition_list_str.split(",")] except ValueError as err: - raise ValueError("Invalid partition string: {}".format( - partition_list_str)) from err + raise ValueError("Invalid partition string: {}".format(partition_list_str)) from err if len(partitions) != pp_size: raise ValueError(f"{len(partitions)=} does not match {pp_size=}.") if sum(partitions) != num_hidden_layers: - raise ValueError( - f"{sum(partitions)=} does not match {num_hidden_layers=}.") + raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.") start_layer = sum(partitions[:pp_rank]) end_layer = start_layer + partitions[pp_rank] - return (start_layer, end_layer) \ No newline at end of file + return (start_layer, end_layer) diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py index ee161c94..97cb9e89 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py @@ -9,10 +9,10 @@ import struct import threading import time from collections import OrderedDict, defaultdict, deque -from collections.abc import Iterator +from collections.abc import Callable, Iterator from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any import httpx import msgspec @@ -22,22 +22,21 @@ import torch import zmq from mooncake.engine import TransferEngine # type: ignore from vllm.config import VllmConfig -from 
vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) -from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, - get_tp_group, get_world_group) +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank, get_tp_group, get_world_group from vllm.logger import logger from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector import \ - GET_META_MSG -from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import \ - global_te +from vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector import GET_META_MSG +from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import global_te from vllm_ascend.distributed.kv_transfer.utils.utils import ( - align_memory, get_transfer_timeout_value, kv_alltoall_and_rearrange) + align_memory, + get_transfer_timeout_value, + kv_alltoall_and_rearrange, +) from vllm_ascend.utils import npu_stream_switch # isort: off @@ -59,28 +58,28 @@ class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True): @dataclass class ReqMeta: local_block_ids: list[int] - token_ids: Optional[list[int]] + token_ids: list[int] | None # Not None if layer-wise is disabled remote_block_ids: list[int] - remote_engine_id: Optional[str] - remote_host: Optional[str] - remote_port: Optional[int] - remote_te_rpc_port: Optional[int] - remote_kv_caches_base_addr: Optional[list[int]] - metaserver: Optional[str] - chunk_finish: Optional[bool] + remote_engine_id: str | None + remote_host: str | None + remote_port: int | None + remote_te_rpc_port: int | None + remote_kv_caches_base_addr: list[int] | None + metaserver: str | None + chunk_finish: bool | None @dataclass class SendTask: send_request: dict[str, ReqMeta] = field(default_factory=dict) # pd_head_ratio == 1 use - wait_event: Optional[torch.npu.Event] = None + wait_event: torch.npu.Event | None = None # pd_head_ratio > 1 use - k_cache: Optional[torch.Tensor] = None - v_cache: Optional[torch.Tensor] = None + k_cache: torch.Tensor | None = None + v_cache: torch.Tensor | None = None layer_idx: int = 0 - rearrange_block_ids: Optional[list[int]] = None + rearrange_block_ids: list[int] | None = None @dataclass @@ -94,13 +93,13 @@ class TransferMeta: @dataclass class SendReqInfo: local_block_ids: list[int] - remote_block_ids: List[int] + remote_block_ids: list[int] remote_cache_tokens: int local_transferred_tokens: int local_computed_tokens: int request: "Request" - def extend_local_block_ids(self, new_block_ids: List[int]) -> None: + def extend_local_block_ids(self, new_block_ids: list[int]) -> None: """extend local block ids for this step""" self.local_block_ids.extend(new_block_ids) @@ -113,12 +112,18 @@ class SendReqInfo: self.local_transferred_tokens = transferred_tokens def unpack(self): - return self.local_block_ids, self.remote_block_ids, self.remote_cache_tokens, self.local_transferred_tokens, self.local_computed_tokens, self.request + return ( + self.local_block_ids, + self.remote_block_ids, + self.remote_cache_tokens, + self.local_transferred_tokens, + self.local_computed_tokens, + self.request, + ) @dataclass class 
SizedDict(OrderedDict): - def __init__(self, max_size=16000, *args, **kwargs): self.max_size = max_size super().__init__(*args, **kwargs) @@ -138,7 +143,6 @@ class SizedDict(OrderedDict): class KVCacheSendingLayerThread(threading.Thread): - def __init__( self, engine: TransferEngine, @@ -190,12 +194,9 @@ class KVCacheSendingLayerThread(threading.Thread): try: self._transfer_kv_cache(send_task) except Exception as e: - logger.error( - f"Failed to transfer KV cache for layer idx {send_task.layer_idx}, {e}" - ) + logger.error(f"Failed to transfer KV cache for layer idx {send_task.layer_idx}, {e}") - def get_transfer_meta(self, send_task: SendTask, req_id: str, - req_meta: ReqMeta): + def get_transfer_meta(self, send_task: SendTask, req_id: str, req_meta: ReqMeta): src_list: list[str] = [] dst_list: list[str] = [] length_list: list[int] = [] @@ -209,40 +210,36 @@ class KVCacheSendingLayerThread(threading.Thread): if self.pd_head_ratio == 1: if self.use_sparse: layer_local_kv_base_addr = [ - local_kv_base_addr[i] for i in - [3 * layer_idx, 3 * layer_idx + 1, 3 * layer_idx + 2] + local_kv_base_addr[i] for i in [3 * layer_idx, 3 * layer_idx + 1, 3 * layer_idx + 2] ] layer_remote_kv_base_addr = [ remote_kv_base_addrs[i] # type:ignore - for i in - [3 * layer_idx, 3 * layer_idx + 1, 3 * layer_idx + 2] + for i in [3 * layer_idx, 3 * layer_idx + 1, 3 * layer_idx + 2] ] else: - layer_local_kv_base_addr = [ - local_kv_base_addr[i] - for i in [2 * layer_idx, 2 * layer_idx + 1] - ] + layer_local_kv_base_addr = [local_kv_base_addr[i] for i in [2 * layer_idx, 2 * layer_idx + 1]] layer_remote_kv_base_addr = [ remote_kv_base_addrs[i] # type:ignore for i in [2 * layer_idx, 2 * layer_idx + 1] ] - grouped_remote_block_ids, grouped_local_block_ids = \ - group_concurrent_contiguous(remote_block_ids, local_block_ids) + grouped_remote_block_ids, grouped_local_block_ids = group_concurrent_contiguous( + remote_block_ids, local_block_ids + ) for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( - zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)): + zip(layer_local_kv_base_addr, layer_remote_kv_base_addr) + ): if self.use_mla: - block_len = (self.block_len[k % 2]) + block_len = self.block_len[k % 2] elif self.use_sparse: - block_len = (self.block_len[k % 3]) + block_len = self.block_len[k % 3] else: - block_len = (self.block_len[0]) + block_len = self.block_len[0] for group_remote_block_id, group_local_block_id in zip( - grouped_remote_block_ids, grouped_local_block_ids): - src = src_layer_base_addr + group_local_block_id[ - 0] * block_len - dst = dst_layer_base_addr + group_remote_block_id[ - 0] * block_len + grouped_remote_block_ids, grouped_local_block_ids + ): + src = src_layer_base_addr + group_local_block_id[0] * block_len + dst = dst_layer_base_addr + group_remote_block_id[0] * block_len length = len(group_local_block_id) * block_len src_list.append(src) dst_list.append(dst) @@ -251,13 +248,9 @@ class KVCacheSendingLayerThread(threading.Thread): rearrange_block_ids = send_task.rearrange_block_ids rearrange_block_dict = { value: index - for index, value in enumerate( - rearrange_block_ids) # type:ignore + for index, value in enumerate(rearrange_block_ids) # type:ignore } - layer_local_kv_base_addr = [ - self.k_buffer.data_ptr(), - self.v_buffer.data_ptr() - ] + layer_local_kv_base_addr = [self.k_buffer.data_ptr(), self.v_buffer.data_ptr()] layer_remote_kv_base_addr = [ remote_kv_base_addrs[i] # type:ignore @@ -266,16 +259,17 @@ class KVCacheSendingLayerThread(threading.Thread): src_list, 
dst_list, length_list = [], [], [] for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( - zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)): + zip(layer_local_kv_base_addr, layer_remote_kv_base_addr) + ): block_len = self.block_len[0] remote_block_len = self.block_len[0] * self.pd_head_ratio - for remote_block_id, local_block_id in zip( - remote_block_ids, local_block_ids): - src = src_layer_base_addr + rearrange_block_dict[ - local_block_id] * block_len - dst = dst_layer_base_addr + remote_block_id * remote_block_len + block_len * ( - (self.tp_rank // self.num_head_replica) % - self.pd_head_ratio) + for remote_block_id, local_block_id in zip(remote_block_ids, local_block_ids): + src = src_layer_base_addr + rearrange_block_dict[local_block_id] * block_len + dst = ( + dst_layer_base_addr + + remote_block_id * remote_block_len + + block_len * ((self.tp_rank // self.num_head_replica) % self.pd_head_ratio) + ) src_list.append(src) dst_list.append(dst) length_list.append(block_len) @@ -288,21 +282,17 @@ class KVCacheSendingLayerThread(threading.Thread): value = send_task.v_cache key = key.view(-1, key.shape[-1]) # type:ignore value = value.view(-1, key.shape[-1]) # type:ignore - self.k_buffer[:key.shape[0]].copy_(key) # [:4, 128] -> - self.v_buffer[:value.shape[0]].copy_(value) + self.k_buffer[: key.shape[0]].copy_(key) # [:4, 128] -> + self.v_buffer[: value.shape[0]].copy_(value) # Merge transmission tasks of the same session session_meta: dict[str, TransferMeta] = {} for req_id, req_meta in send_task.send_request.items(): session_id = f"{req_meta.remote_host}:{req_meta.remote_te_rpc_port}" - if session_id not in session_meta.keys(): - session_meta[session_id] = TransferMeta(src=[], - dst=[], - length=[], - req_ids=[]) + if session_id not in session_meta: + session_meta[session_id] = TransferMeta(src=[], dst=[], length=[], req_ids=[]) - (src_list, dst_list, - length_list) = self.get_transfer_meta(send_task, req_id, req_meta) + (src_list, dst_list, length_list) = self.get_transfer_meta(send_task, req_id, req_meta) session_meta[session_id].src.extend(src_list) session_meta[session_id].dst.extend(dst_list) @@ -323,8 +313,8 @@ class KVCacheSendingLayerThread(threading.Thread): for session_id, transfer_meta in session_meta.items(): if len(transfer_meta.src) > 0: ret = self.engine.batch_transfer_sync_write( - session_id, transfer_meta.src, transfer_meta.dst, - transfer_meta.length) + session_id, transfer_meta.src, transfer_meta.dst, transfer_meta.length + ) if ret < 0: logger.error( f"Mooncake transfer failed for send requests {transfer_meta.req_ids} kv cache to {session_id}" @@ -345,11 +335,16 @@ class KVCacheSendingLayerThread(threading.Thread): class KVCacheRecvingLayerThread(threading.Thread): - - def __init__(self, tp_rank: int, side_channel_port: int, tp_size: int, - pd_head_ratio: int, local_engine_id: str, - metadata: MooncakeAgentMetadata, - ready_event: threading.Event): + def __init__( + self, + tp_rank: int, + side_channel_port: int, + tp_size: int, + pd_head_ratio: int, + local_engine_id: str, + metadata: MooncakeAgentMetadata, + ready_event: threading.Event, + ): super().__init__(daemon=True, name="KVCacheRecvingLayerThread") self.tp_rank = tp_rank self.tp_size = tp_size @@ -409,83 +404,69 @@ class KVCacheRecvingLayerThread(threading.Thread): logger.info("Got GET META INFO for request %s", msg[0]) sock.send_multipart((identity, b"", encoded_data)) elif msg[0] == DONE_SENDING_MSG: - logger.debug("Got DONE_RECVING_MSG for request %s", - msg[1]) + logger.debug("Got 
DONE_SENDING_MSG for request %s", msg[1]) request_id = msg[1] self.update_task(request_id) sock.send_multipart((identity, b"", b"ACK")) else: - logger.error( - "Connection listener got unexpected message %s", - msg) + logger.error("Connection listener got unexpected message %s", msg) except Exception as e: logger.error("Failed to decode message: %s", e) class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata): - def __init__(self): self.requests: dict[str, ReqMeta] = {} - def add_new_req(self, - request_id: str, - local_block_ids: list[int], - kv_transfer_params: dict[str, Any], - token_ids: Optional[list[int]] = None, - chunk_finish: bool = False): + def add_new_req( + self, + request_id: str, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + token_ids: list[int] | None = None, + chunk_finish: bool = False, + ): self.requests[request_id] = ReqMeta( token_ids=token_ids or [], local_block_ids=local_block_ids, remote_block_ids=kv_transfer_params.get("remote_block_ids", []), - remote_engine_id=kv_transfer_params.get("remote_engine_id", None), - remote_host=kv_transfer_params.get("remote_host", None), - remote_port=kv_transfer_params.get("remote_port", None), - remote_te_rpc_port=kv_transfer_params.get("remote_te_rpc_port", - None), - remote_kv_caches_base_addr=kv_transfer_params.get( - "remote_kv_caches_base_addr", None), - metaserver=kv_transfer_params.get("metaserver", None), - chunk_finish=chunk_finish) + remote_engine_id=kv_transfer_params.get("remote_engine_id"), + remote_host=kv_transfer_params.get("remote_host"), + remote_port=kv_transfer_params.get("remote_port"), + remote_te_rpc_port=kv_transfer_params.get("remote_te_rpc_port"), + remote_kv_caches_base_addr=kv_transfer_params.get("remote_kv_caches_base_addr"), + metaserver=kv_transfer_params.get("metaserver"), + chunk_finish=chunk_finish, + ) class MooncakeLayerwiseConnector(KVConnectorBase_V1): - - def __init__(self, - vllm_config: VllmConfig, - role: KVConnectorRole, - kv_cache_config: Optional[KVCacheConfig] = None): + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: KVCacheConfig | None = None): super().__init__(vllm_config, role, kv_cache_config) assert vllm_config.kv_transfer_config is not None self.engine_id = vllm_config.kv_transfer_config.engine_id self._connector_metadata = MooncakeLayerwiseConnectorMetadata() if role == KVConnectorRole.SCHEDULER: - self.connector_scheduler: Optional[MooncakeLayerwiseConnectorScheduler] = \ - MooncakeLayerwiseConnectorScheduler(vllm_config, str(self.engine_id)) - self.connector_worker: Optional[ - MooncakeLayerwiseConnectorWorker] = None + self.connector_scheduler: MooncakeLayerwiseConnectorScheduler | None = MooncakeLayerwiseConnectorScheduler( + vllm_config, str(self.engine_id) + ) + self.connector_worker: MooncakeLayerwiseConnectorWorker | None = None elif role == KVConnectorRole.WORKER: self.connector_scheduler = None - self.connector_worker = MooncakeLayerwiseConnectorWorker( - vllm_config, str(self.engine_id)) + self.connector_worker = MooncakeLayerwiseConnectorWorker(vllm_config, str(self.engine_id)) ############################################################ # Scheduler Side Methods ############################################################ - def get_num_new_matched_tokens( - self, request: "Request", - num_computed_tokens: int) -> tuple[int, bool]: + def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: int) -> tuple[int, bool]: assert self.connector_scheduler is not None - return 
self.connector_scheduler.get_num_new_matched_tokens( - request, num_computed_tokens) + return self.connector_scheduler.get_num_new_matched_tokens(request, num_computed_tokens) - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): + def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int): assert self.connector_scheduler is not None - return self.connector_scheduler.update_state_after_alloc( - request, blocks, num_external_tokens) + return self.connector_scheduler.update_state_after_alloc(request, blocks, num_external_tokens) def build_connector_meta( self, @@ -498,7 +479,7 @@ class MooncakeLayerwiseConnector(KVConnectorBase_V1): self, request: "Request", block_ids: list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: + ) -> tuple[bool, dict[str, Any] | None]: assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) @@ -509,35 +490,29 @@ class MooncakeLayerwiseConnector(KVConnectorBase_V1): assert self.connector_worker is not None self.connector_worker.register_kv_caches(kv_caches) - def get_finished(self, - finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: """Get the finished recving and sending requests.""" assert self.connector_worker is not None return self.connector_worker.get_finished() - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: assert self.connector_worker is not None - assert isinstance(self._connector_metadata, - MooncakeLayerwiseConnectorMetadata) + assert isinstance(self._connector_metadata, MooncakeLayerwiseConnectorMetadata) self.connector_worker.start_load_kv(self._connector_metadata) def wait_for_layer_load(self, layer_name: str) -> None: """MooncakeLayerwiseConnector does not do layerwise saving.""" assert self.connector_worker is not None - assert isinstance(self._connector_metadata, - MooncakeLayerwiseConnectorMetadata) + assert isinstance(self._connector_metadata, MooncakeLayerwiseConnectorMetadata) self.connector_worker.wait_for_layer_load(layer_name) - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: + def save_kv_layer( + self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs + ) -> None: """MooncakeLayerwiseConnector does not save explicitly.""" assert self.connector_worker is not None - assert isinstance(self._connector_metadata, - MooncakeLayerwiseConnectorMetadata) - self.connector_worker.save_kv_layer(layer_name, kv_layer, - attn_metadata, - self._connector_metadata) + assert isinstance(self._connector_metadata, MooncakeLayerwiseConnectorMetadata) + self.connector_worker.save_kv_layer(layer_name, kv_layer, attn_metadata, self._connector_metadata) def wait_for_save(self): """MooncakeLayerwiseConnector does not save explicitly.""" @@ -557,40 +532,32 @@ class MooncakeLayerwiseConnectorScheduler: # Handshake base port self.side_channel_port = ( - vllm_config.kv_transfer_config.kv_port + - vllm_config.parallel_config.data_parallel_rank * - vllm_config.parallel_config.tensor_parallel_size) + vllm_config.kv_transfer_config.kv_port + + vllm_config.parallel_config.data_parallel_rank * vllm_config.parallel_config.tensor_parallel_size + ) # Requests that need to start recv. 
# New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. - self._reqs_need_recv: dict[str, tuple[Request, list[int], - list[int]]] = {} + self._reqs_need_recv: dict[str, tuple[Request, list[int], list[int]]] = {} self._reqs_need_send_layerwise: dict[str, SendReqInfo] = {} self.executor = ThreadPoolExecutor(32) - tls_config: dict[ - str, Any] = vllm_config.kv_transfer_config.get_from_extra_config( - "tls_config", {}) - ssl_keyfile = tls_config.get("ssl_keyfile", None) - ssl_certfile = tls_config.get("ssl_certfile", None) + tls_config: dict[str, Any] = vllm_config.kv_transfer_config.get_from_extra_config("tls_config", {}) + ssl_keyfile = tls_config.get("ssl_keyfile") + ssl_certfile = tls_config.get("ssl_certfile") ssl_ca_certs = tls_config.get("ssl_ca_certs", False) - ssl_keyfile_password = tls_config.get("ssl_keyfile_password", None) + ssl_keyfile_password = tls_config.get("ssl_keyfile_password") self.cert_path = (ssl_certfile, ssl_keyfile, ssl_keyfile_password) self.ssl_enable = tls_config.get("ssl_enable", False) self.ca_path = ssl_ca_certs if self.ssl_enable: self.metaserver_client = httpx.Client( - limits=httpx.Limits(max_connections=100000), - timeout=None, - cert=self.cert_path, - verify=self.ca_path) + limits=httpx.Limits(max_connections=100000), timeout=None, cert=self.cert_path, verify=self.ca_path + ) else: - self.metaserver_client = httpx.Client( - limits=httpx.Limits(max_connections=100000), timeout=None) + self.metaserver_client = httpx.Client(limits=httpx.Limits(max_connections=100000), timeout=None) - def get_num_new_matched_tokens( - self, request: "Request", - num_computed_tokens: int) -> tuple[int, bool]: + def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: int) -> tuple[int, bool]: """ For remote prefill, pull all prompt blocks from remote asynchronously relative to engine execution. @@ -608,9 +575,10 @@ class MooncakeLayerwiseConnectorScheduler: params = request.kv_transfer_params logger.debug( - "MooncakeLayerwiseConnector get_num_new_matched_tokens: " - "num_computed_tokens=%s, kv_transfer_params=%s", - num_computed_tokens, params) + "MooncakeLayerwiseConnector get_num_new_matched_tokens: num_computed_tokens=%s, kv_transfer_params=%s", + num_computed_tokens, + params, + ) if params is not None and params.get("do_remote_prefill"): # Remote prefill: get all prompt blocks from remote. @@ -622,33 +590,29 @@ class MooncakeLayerwiseConnectorScheduler: # No remote prefill for this request. return 0, False - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): - + def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int): params = request.kv_transfer_params logger.debug( - "MooncakeLayerwiseConnector update_state_after_alloc: " - "num_external_tokens=%s, kv_transfer_params=%s", - num_external_tokens, params) + "MooncakeLayerwiseConnector update_state_after_alloc: num_external_tokens=%s, kv_transfer_params=%s", + num_external_tokens, + params, + ) if params is not None and params.get("do_remote_prefill"): - local_block_ids = (blocks.get_unhashed_block_ids() - if num_external_tokens > 0 else []) + local_block_ids = blocks.get_unhashed_block_ids() if num_external_tokens > 0 else [] # Get unhashed blocks to pull from remote. 
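+            # Note: blocks already covered by the local prefix cache are hashed, so only the unhashed tail has to be transferred.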
logger.debug( f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need recv queue" ) self._reqs_need_recv[request.request_id] = ( request, - [], #request._all_token_ids, - local_block_ids) + [], # request._all_token_ids, + local_block_ids, + ) params["do_remote_prefill"] = False - logger.info( - f"Send request: {request.request_id} to proxy metaserver: {params.get('metaserver', None)}" - ) + logger.info(f"Send request: {request.request_id} to proxy metaserver: {params.get('metaserver', None)}") # All parameters here should appear in the returned dict of # request_finished in the scheduler side except "request_id". kv_transfer_params = dict( @@ -662,27 +626,26 @@ class MooncakeLayerwiseConnectorScheduler: remote_port=self.side_channel_port, ) - future = self.executor.submit(self._access_metaserver, - url=params.get("metaserver", None), - message=kv_transfer_params) + future = self.executor.submit( + self._access_metaserver, url=params.get("metaserver", None), message=kv_transfer_params + ) def handle_exception(future): if future.exception(): - logger.error( - f"Access metaserver fail: {future.exception()}") + logger.error(f"Access metaserver fail: {future.exception()}") future.add_done_callback(handle_exception) # The layerwise prefiller adds requests that need to be sent if params is not None and params.get("do_remote_decode"): - local_block_ids = (blocks.get_block_ids()[0]) + local_block_ids = blocks.get_block_ids()[0] logger.debug( f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need send queue" ) remote_block_ids = copy.deepcopy(params["remote_block_ids"]) remote_cache_tokens = ( - (len(request.all_token_ids) + self.block_size - 1) // - self.block_size - len(remote_block_ids)) * self.block_size + (len(request.all_token_ids) + self.block_size - 1) // self.block_size - len(remote_block_ids) + ) * self.block_size local_transferred_tokens = remote_cache_tokens local_computed_tokens = 0 self._reqs_need_send_layerwise[request.request_id] = SendReqInfo( @@ -691,7 +654,8 @@ class MooncakeLayerwiseConnectorScheduler: remote_cache_tokens=remote_cache_tokens, local_transferred_tokens=local_transferred_tokens, local_computed_tokens=local_computed_tokens, - request=request) + request=request, + ) def build_connector_meta( self, @@ -701,16 +665,17 @@ class MooncakeLayerwiseConnectorScheduler: if self.vllm_config.kv_transfer_config.is_kv_consumer: # Loop through scheduled reqs and convert to ReqMeta. - for req_id, (req, token_ids, - block_ids) in self._reqs_need_recv.items(): + for req_id, (req, token_ids, block_ids) in self._reqs_need_recv.items(): assert req.kv_transfer_params is not None # For the case where there are no remote blocks to pull # (block_ids is empty), we don't need to schedule # an async read on the worker side. 
- meta.add_new_req(request_id=req_id, - local_block_ids=block_ids, - kv_transfer_params=req.kv_transfer_params, - token_ids=token_ids) + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + token_ids=token_ids, + ) # Clear the list once workers start the transfers self._reqs_need_recv.clear() @@ -719,80 +684,77 @@ class MooncakeLayerwiseConnectorScheduler: new_reqs = scheduler_output.scheduled_new_reqs scheduled_spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens # update local block ids - for req_id, new_blocks in zip(cached_reqs.req_ids, - cached_reqs.new_block_ids): + for req_id, new_blocks in zip(cached_reqs.req_ids, cached_reqs.new_block_ids): if req_id in self._reqs_need_send_layerwise and new_blocks is not None: - self._reqs_need_send_layerwise[ - req_id].extend_local_block_ids(new_blocks[0]) + self._reqs_need_send_layerwise[req_id].extend_local_block_ids(new_blocks[0]) computed_tokens = dict( list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) - + [(x.req_id, x.num_computed_tokens) for x in new_reqs]) - for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items( - ): + + [(x.req_id, x.num_computed_tokens) for x in new_reqs] + ) + for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items(): if req_id in self._reqs_need_send_layerwise: send_req_info = self._reqs_need_send_layerwise[req_id] # update local computed tokens, not transfer spec decode tokens - spec_decode_tokens = len( - scheduled_spec_decode_tokens[req_id]) if ( - req_id in scheduled_spec_decode_tokens) else 0 + spec_decode_tokens = ( + len(scheduled_spec_decode_tokens[req_id]) if (req_id in scheduled_spec_decode_tokens) else 0 + ) send_req_info.update_computed_tokens( - computed_tokens.get(req_id, 0) + scheduled_tokens - - spec_decode_tokens) + computed_tokens.get(req_id, 0) + scheduled_tokens - spec_decode_tokens + ) - def add_tranfer_task(req_id, - send_req_info: SendReqInfo, - chunk_finish=False): - local_block_ids, remote_block_ids, remote_cache_tokens, local_transferred_tokens, local_computed_tokens, request = send_req_info.unpack( + def add_tranfer_task(req_id, send_req_info: SendReqInfo, chunk_finish=False): + ( + local_block_ids, + remote_block_ids, + remote_cache_tokens, + local_transferred_tokens, + local_computed_tokens, + request, + ) = send_req_info.unpack() + local_trans_block_ids = local_block_ids[ + (local_transferred_tokens // self.block_size) : (local_computed_tokens // self.block_size) + ] + remote_trans_block_ids = remote_block_ids[ + ((local_transferred_tokens - remote_cache_tokens) // self.block_size) : ( + (local_computed_tokens - remote_cache_tokens) // self.block_size + ) + ] + request.kv_transfer_params["remote_block_ids"] = remote_trans_block_ids + assert len(local_trans_block_ids) == len(remote_trans_block_ids), ( + f"len of local trans block ids : {len(local_trans_block_ids)} not equal to " + f"the len of remote trans block ids : {len(remote_trans_block_ids)}" + ) + adjusted_tokens = ( + local_computed_tokens - (self.block_size - 1) if chunk_finish else local_computed_tokens ) - local_trans_block_ids = local_block_ids[( - local_transferred_tokens // - self.block_size):(local_computed_tokens // - self.block_size)] - remote_trans_block_ids = remote_block_ids[( - (local_transferred_tokens - remote_cache_tokens) // - self.block_size):((local_computed_tokens - - remote_cache_tokens) // - self.block_size)] - request.kv_transfer_params[ - "remote_block_ids"] = 
remote_trans_block_ids - assert len(local_trans_block_ids) == len( - remote_trans_block_ids - ), f"len of local trans block ids : {len(local_trans_block_ids)} not equal to the len of remote trans block ids : {len(remote_trans_block_ids)}" - adjusted_tokens = local_computed_tokens - ( - self.block_size - - 1) if chunk_finish else local_computed_tokens logger.info( - f"MooncakeLayerwiseConnector scheduler add transfer task: {req_id=} {local_block_ids=} {remote_block_ids=} {local_trans_block_ids=} {remote_trans_block_ids=} local_computed_tokens={adjusted_tokens} request.all_token_ids={len(request.all_token_ids)}" + f"MooncakeLayerwiseConnector scheduler add transfer task: " + f"{req_id=} {local_block_ids=} {remote_block_ids=} " + f"{local_trans_block_ids=} {remote_trans_block_ids=} " + f"local_computed_tokens={adjusted_tokens} " + f"request.all_token_ids={len(request.all_token_ids)}" ) meta.add_new_req( request_id=req_id, local_block_ids=local_trans_block_ids, kv_transfer_params=request.kv_transfer_params, token_ids=[], - chunk_finish=chunk_finish) + chunk_finish=chunk_finish, + ) # update local_transferred_tokens - local_transferred_tokens = ( - local_computed_tokens // - self.block_size) * self.block_size - send_req_info.update_transferred_tokens( - local_transferred_tokens) + local_transferred_tokens = (local_computed_tokens // self.block_size) * self.block_size + send_req_info.update_transferred_tokens(local_transferred_tokens) # no chunk or last chunk - if send_req_info.local_computed_tokens >= len( - send_req_info.request.all_token_ids): - send_req_info.update_computed_tokens( - send_req_info.local_computed_tokens + - self.block_size - 1) - add_tranfer_task(req_id, - send_req_info, - chunk_finish=True) + if send_req_info.local_computed_tokens >= len(send_req_info.request.all_token_ids): + send_req_info.update_computed_tokens(send_req_info.local_computed_tokens + self.block_size - 1) + add_tranfer_task(req_id, send_req_info, chunk_finish=True) self._reqs_need_send_layerwise.pop(req_id) # chunk - elif (send_req_info.local_computed_tokens // - self.block_size) - ( - send_req_info.local_transferred_tokens // - self.block_size) > 0: + elif (send_req_info.local_computed_tokens // self.block_size) - ( + send_req_info.local_transferred_tokens // self.block_size + ) > 0: add_tranfer_task(req_id, send_req_info) return meta @@ -805,9 +767,7 @@ class MooncakeLayerwiseConnectorScheduler: self.metaserver_client.post(url, json=message) success = True except Exception as e: - logger.error( - f"Failed to connect to metaserver: {url}, retry {retry} time." - ) + logger.error(f"Failed to connect to metaserver: {url}, retry {retry} of 3.") if retry == 3: raise e @@ -815,7 +775,7 @@ class MooncakeLayerwiseConnectorScheduler: self, request: "Request", block_ids: list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: + ) -> tuple[bool, dict[str, Any] | None]: """ Once a request is finished, determine whether request blocks should be freed now or will be sent asynchronously and freed later. 
@@ -829,12 +789,12 @@ class MooncakeLayerwiseConnectorWorker: def __init__(self, vllm_config: VllmConfig, engine_id: str): self._get_prefill_decode_size(vllm_config) - os.environ["ASCEND_TRANSFER_TIMEOUT"] = str( - get_transfer_timeout_value()) + os.environ["ASCEND_TRANSFER_TIMEOUT"] = str(get_transfer_timeout_value()) if self._prefill_tp_size < self._decode_tp_size: raise ValueError( f"prefill_tp_size: {self._prefill_tp_size} must be greater than" - f" or equal to the decode_tp_size: {self._decode_tp_size}") + f" or equal to the decode_tp_size: {self._decode_tp_size}" + ) if TransferEngine is None: raise RuntimeError("mooncake is not available") @@ -849,24 +809,22 @@ class MooncakeLayerwiseConnectorWorker: self.tp_group = get_tp_group() self.kv_caches: dict[str, torch.Tensor] = {} self.side_channel_host = get_ip() - self.total_layers = vllm_config.model_config.get_num_layers( - vllm_config.parallel_config) + self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config) # Handshake base port self.side_channel_port = ( - vllm_config.kv_transfer_config.kv_port + - vllm_config.parallel_config.data_parallel_rank * - vllm_config.parallel_config.tensor_parallel_size) + vllm_config.kv_transfer_config.kv_port + + vllm_config.parallel_config.data_parallel_rank * vllm_config.parallel_config.tensor_parallel_size + ) self.handshake_port = self.side_channel_port + self.tp_rank self.sockets: dict = {} logger.info("Initializing Mooncake worker %s", engine_id) - self.engine = global_te.get_transfer_engine(self.side_channel_host, - device_name=None) + self.engine = global_te.get_transfer_engine(self.side_channel_host, device_name=None) self.te_rpc_port = self.engine.get_rpc_port() # Background thread for sending or receiving KV caches. - self.kv_recv_layer_thread: Optional[KVCacheRecvingLayerThread] = None - self.kv_send_layer_thread: Optional[KVCacheSendingLayerThread] = None + self.kv_recv_layer_thread: KVCacheRecvingLayerThread | None = None + self.kv_send_layer_thread: KVCacheSendingLayerThread | None = None self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size @@ -883,64 +841,61 @@ class MooncakeLayerwiseConnectorWorker: self.decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata) self.encoder = msgspec.msgpack.Encoder() - self.remote_kv_caches_base_addr: dict[str, dict[int, list[int]]] = \ - SizedDict() - self.remote_te_port: dict[str, dict[int, int]] = \ - SizedDict() + self.remote_kv_caches_base_addr: dict[str, dict[int, list[int]]] = SizedDict() + self.remote_te_port: dict[str, dict[int, int]] = SizedDict() self.remote_sockets_lock = threading.Lock() self.remote_sockets: dict[ # type: ignore - str, deque[zmq.Socket]] = defaultdict( # type: ignore - deque) + str, deque[zmq.Socket] + ] = defaultdict( # type: ignore + deque + ) self.remote_poller = zmq.Poller() # type: ignore self.timeout = 1.0 # seconds - self.k_buffer: Optional[torch.Tensor] = None - self.v_buffer: Optional[torch.Tensor] = None + self.k_buffer: torch.Tensor | None = None + self.v_buffer: torch.Tensor | None = None def _get_prefill_decode_size(self, vllm_config: VllmConfig): # get prefill tp and dp size from extra config - prefill_parallel_config: dict[ - str, Any] = vllm_config.kv_transfer_config.get_from_extra_config( - "prefill", {}) + prefill_parallel_config: dict[str, Any] = vllm_config.kv_transfer_config.get_from_extra_config("prefill", {}) - assert "tp_size" in prefill_parallel_config.keys() + assert "tp_size" in prefill_parallel_config self._prefill_tp_size = 
prefill_parallel_config["tp_size"] - assert "dp_size" in prefill_parallel_config.keys() + assert "dp_size" in prefill_parallel_config self._prefill_dp_size = prefill_parallel_config["dp_size"] # get decode tp and dp size from extra config - decode_parallel_config: dict[ - str, Any] = vllm_config.kv_transfer_config.get_from_extra_config( - "decode", {}) - assert "tp_size" in decode_parallel_config.keys() + decode_parallel_config: dict[str, Any] = vllm_config.kv_transfer_config.get_from_extra_config("decode", {}) + assert "tp_size" in decode_parallel_config self._decode_tp_size = decode_parallel_config["tp_size"] - assert "dp_size" in decode_parallel_config.keys() + assert "dp_size" in decode_parallel_config self._decode_dp_size = decode_parallel_config["dp_size"] def create_kv_buffer(self, first_kv_cache): if self.pd_head_ratio > 1: # regesit kv buffer for tp inequal alignment = 2 * 1024 * 1024 - self.k_buffer = torch.zeros(first_kv_cache.numel() + alignment, - dtype=first_kv_cache.dtype, - device=first_kv_cache.device) - self.k_buffer = align_memory( - self.k_buffer, alignment)[:first_kv_cache.numel()].view( - -1, first_kv_cache.shape[-1]) - self.v_buffer = torch.zeros(first_kv_cache.numel() + alignment, - dtype=first_kv_cache.dtype, - device=first_kv_cache.device) - self.v_buffer = align_memory( - self.v_buffer, alignment)[:first_kv_cache.numel()].view( - -1, first_kv_cache.shape[-1]) + self.k_buffer = torch.zeros( + first_kv_cache.numel() + alignment, dtype=first_kv_cache.dtype, device=first_kv_cache.device + ) + self.k_buffer = align_memory(self.k_buffer, alignment)[: first_kv_cache.numel()].view( + -1, first_kv_cache.shape[-1] + ) + self.v_buffer = torch.zeros( + first_kv_cache.numel() + alignment, dtype=first_kv_cache.dtype, device=first_kv_cache.device + ) + self.v_buffer = align_memory(self.v_buffer, alignment)[: first_kv_cache.numel()].view( + -1, first_kv_cache.shape[-1] + ) for tensor in (self.k_buffer, self.v_buffer): - assert tensor.data_ptr( - ) % alignment == 0, "The address of the registered kv cache should be aligned to 2M" - ret_value = self.engine.register_memory( - tensor.data_ptr(), tensor.numel()) + assert tensor.data_ptr() % alignment == 0, ( + "The address of the registered kv cache should be aligned to 2M" + ) + ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel()) logger.info( - f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}" + f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} " + f"{tensor.numel()} {ret_value=}" ) if ret_value != 0: raise RuntimeError("Mooncake memory registration failed. 
") @@ -953,9 +908,9 @@ class MooncakeLayerwiseConnectorWorker: self.create_kv_buffer(first_kv_cache) # TODO(tms): Find a more robust way to detect and handle MLA - self.use_mla = first_kv_cache_tuple[0].size( - -1) != first_kv_cache_tuple[1].size(-1) and len( - first_kv_cache_tuple) == 2 + self.use_mla = ( + first_kv_cache_tuple[0].size(-1) != first_kv_cache_tuple[1].size(-1) and len(first_kv_cache_tuple) == 2 + ) self.use_sparse = len(first_kv_cache_tuple) == 3 if self.use_mla: # MLA case.[num_block, block_size, 1, hidden_dim] @@ -965,11 +920,14 @@ class MooncakeLayerwiseConnectorWorker: block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] self.block_len = [ first_kv_cache[0].element_size() * math.prod(block_shape_norm), - first_kv_cache[1].element_size() * math.prod(block_shape_pe) + first_kv_cache[1].element_size() * math.prod(block_shape_pe), ] logger.info( "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", - self.num_blocks, block_shape_norm, block_shape_pe) + self.num_blocks, + block_shape_norm, + block_shape_pe, + ) elif self.use_sparse: self.num_blocks = first_kv_cache.shape[0] block_rank = 3 # [block_size, latent_dim] @@ -979,12 +937,15 @@ class MooncakeLayerwiseConnectorWorker: self.block_len = [ first_kv_cache[0].element_size() * math.prod(block_shape_norm), first_kv_cache[1].element_size() * math.prod(block_shape_pe), - first_kv_cache[2].element_size() * math.prod(block_shape_k) + first_kv_cache[2].element_size() * math.prod(block_shape_k), ] logger.info( "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s", - self.num_blocks, block_shape_norm, block_shape_pe, - block_shape_k) + self.num_blocks, + block_shape_norm, + block_shape_pe, + block_shape_k, + ) else: # [num_block, block_size, num_head, hidden_dim] self.num_blocks = first_kv_cache.shape[0] @@ -992,12 +953,14 @@ class MooncakeLayerwiseConnectorWorker: block_rank = 3 # [block_size, kv_heads, head_dim] block_shape = first_kv_cache.shape[-block_rank:] self.block_len = [kv_elem_size * math.prod(block_shape)] - logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, - block_shape) + logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape) logger.info( "Registering KV_Caches. 
use_mla: %s, use_sparse: %s, shape %s", - self.use_mla, self.use_sparse, first_kv_cache.shape) + self.use_mla, + self.use_sparse, + first_kv_cache.shape, + ) self.kv_caches = kv_caches kv_caches_base_addr = [] @@ -1020,9 +983,7 @@ class MooncakeLayerwiseConnectorWorker: ptrs.append(base_addr) lengths.append(region_len) else: - cache_list = [ - cache_or_caches - ] if self.use_mla or self.use_sparse else cache_or_caches + cache_list = [cache_or_caches] if self.use_mla or self.use_sparse else cache_or_caches for cache in cache_list: base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len[0] @@ -1054,27 +1015,34 @@ class MooncakeLayerwiseConnectorWorker: k_buffer=self.k_buffer, v_buffer=self.v_buffer, resharding_stream=self.resharding_stream, - callback_func=self.send_done_send_signal) + callback_func=self.send_done_send_signal, + ) self.kv_send_layer_thread.start() ready_event.wait() if self.vllm_config.kv_transfer_config.is_kv_consumer: ready_event = threading.Event() self.kv_recv_layer_thread = KVCacheRecvingLayerThread( - self.tp_rank, self.side_channel_port, self.tp_size, - self.pd_head_ratio, self.engine_id, metadata, ready_event) + self.tp_rank, + self.side_channel_port, + self.tp_size, + self.pd_head_ratio, + self.engine_id, + metadata, + ready_event, + ) self.kv_recv_layer_thread.start() ready_event.wait() def get_finished(self) -> tuple[set[str], set[str]]: done_recving = ( - self.kv_recv_layer_thread. - get_and_clear_finished_requests( # type: ignore[union-attr] - ) if self.vllm_config.kv_transfer_config.is_kv_consumer else set()) + self.kv_recv_layer_thread.get_and_clear_finished_requests( # type: ignore[union-attr] + ) + if self.vllm_config.kv_transfer_config.is_kv_consumer + else set() + ) if len(done_recving) > 0: - logger.info( - "Number of completed KV cache recv requests: %d, receive " - "requests: %d", 0, len(done_recving)) + logger.info("Number of completed KV cache recv requests: %d, receive requests: %d", 0, len(done_recving)) return set(), done_recving def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata): @@ -1086,14 +1054,16 @@ class MooncakeLayerwiseConnectorWorker: with self.kv_recv_layer_thread.lock: self.kv_recv_layer_thread.task_tracker[req_id] = 0 - def save_kv_layer(self, layer_name: str, kv_layer: Tuple[torch.Tensor, - torch.Tensor], - attn_metadata: "AttentionMetadata", - connector_metadata: MooncakeLayerwiseConnectorMetadata, - **kwargs) -> None: + def save_kv_layer( + self, + layer_name: str, + kv_layer: tuple[torch.Tensor, torch.Tensor], + attn_metadata: "AttentionMetadata", + connector_metadata: MooncakeLayerwiseConnectorMetadata, + **kwargs, + ) -> None: """MooncakeLayerwiseConnector does not save explicitly.""" - if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys( - ): + if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys(): # enable decode prefix cache if self.use_mla or self.use_sparse: num_need_send = self._decode_tp_size @@ -1107,13 +1077,11 @@ class MooncakeLayerwiseConnectorWorker: replica_group_idx = self.tp_rank % num_replica_groups req_ids = sorted(list(connector_metadata.requests.keys())) selected_req_ids = [ - req_id for i, req_id in enumerate(req_ids) - if i % num_replica_groups == replica_group_idx + req_id for i, req_id in enumerate(req_ids) if i % num_replica_groups == replica_group_idx ] if selected_req_ids: if self.use_mla or self.use_sparse: - reshape_cache_event = attn_metadata[ - layer_name].reshape_cache_event + 
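Review note: the `use_mla` / `use_sparse` detection reformatted earlier in this hunk series keys off tuple arity and mismatched last dimensions. A self-contained toy version of that heuristic, with made-up shapes for illustration:

```python
# Toy illustration of the KV-layout heuristic: MLA splits the cache into
# (nope, rope) tensors with different hidden dims, the sparse layout adds a
# third tensor, and the standard layout keeps matching K/V shapes.
import torch


def classify_kv_cache(kv_cache_tuple: tuple[torch.Tensor, ...]) -> str:
    use_mla = (
        len(kv_cache_tuple) == 2
        and kv_cache_tuple[0].size(-1) != kv_cache_tuple[1].size(-1)
    )
    if use_mla:
        return "mla"
    if len(kv_cache_tuple) == 3:
        return "sparse"
    return "standard"


# [num_blocks, block_size, 1, hidden_dim]-style dummies:
nope = torch.empty(8, 128, 1, 512)
rope = torch.empty(8, 128, 1, 64)
assert classify_kv_cache((nope, rope)) == "mla"

k = torch.empty(8, 128, 8, 128)
v = torch.empty(8, 128, 8, 128)
assert classify_kv_cache((k, v)) == "standard"
```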
@@ -1107,13 +1077,11 @@ class MooncakeLayerwiseConnectorWorker:
             replica_group_idx = self.tp_rank % num_replica_groups
             req_ids = sorted(list(connector_metadata.requests.keys()))
             selected_req_ids = [
-                req_id for i, req_id in enumerate(req_ids)
-                if i % num_replica_groups == replica_group_idx
+                req_id for i, req_id in enumerate(req_ids) if i % num_replica_groups == replica_group_idx
             ]
             if selected_req_ids:
                 if self.use_mla or self.use_sparse:
-                    reshape_cache_event = attn_metadata[
-                        layer_name].reshape_cache_event
+                    reshape_cache_event = attn_metadata[layer_name].reshape_cache_event
                 else:
                     reshape_cache_event = attn_metadata.reshape_cache_event

@@ -1121,28 +1089,31 @@ class MooncakeLayerwiseConnectorWorker:
                     assert self.resharding_stream is not None
                     with npu_stream_switch(self.resharding_stream):
                         reshape_cache_event.wait()
-                        rearrange_block_ids = sorted({
-                            block_id
-                            for req_id in selected_req_ids
-                            for block_id in
-                            connector_metadata.requests[req_id].local_block_ids
-                        })
+                        rearrange_block_ids = sorted(
+                            {
+                                block_id
+                                for req_id in selected_req_ids
+                                for block_id in connector_metadata.requests[req_id].local_block_ids
+                            }
+                        )
                         keys = kv_layer[0][rearrange_block_ids].clone()
                         values = kv_layer[1][rearrange_block_ids].clone()
                         # sort kv caches for each block
-                        keys = keys.view(keys.size(0), self.pd_head_ratio, -1,
-                                         *keys.shape[2:]).transpose(
-                                             0, 1).reshape_as(keys)
-                        values = values.view(values.size(0),
-                                             self.pd_head_ratio, -1,
-                                             *values.shape[2:]).transpose(
-                                                 0, 1).reshape_as(values)
+                        keys = (
+                            keys.view(keys.size(0), self.pd_head_ratio, -1, *keys.shape[2:])
+                            .transpose(0, 1)
+                            .reshape_as(keys)
+                        )
+                        values = (
+                            values.view(values.size(0), self.pd_head_ratio, -1, *values.shape[2:])
+                            .transpose(0, 1)
+                            .reshape_as(values)
+                        )
                         # reshard kv cache
                         keys = keys.reshape(-1, *kv_layer[0].shape[2:])
                         values = values.reshape(-1, *kv_layer[1].shape[2:])
-                        (keys, values) = kv_alltoall_and_rearrange(
-                            self.pd_head_ratio, keys, values)
+                        (keys, values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys, values)
                 else:
                     keys = None
                     values = None
@@ -1150,26 +1121,23 @@ class MooncakeLayerwiseConnectorWorker:
                 assert self.kv_send_layer_thread is not None
                 assert reshape_cache_event is not None

-                send_task = SendTask(wait_event=reshape_cache_event,
-                                     k_cache=keys,
-                                     v_cache=values,
-                                     layer_idx=self.current_layer,
-                                     rearrange_block_ids=rearrange_block_ids)
+                send_task = SendTask(
+                    wait_event=reshape_cache_event,
+                    k_cache=keys,
+                    v_cache=values,
+                    layer_idx=self.current_layer,
+                    rearrange_block_ids=rearrange_block_ids,
+                )
                 for req_id, req_meta in connector_metadata.requests.items():
                     if req_id in selected_req_ids:
-                        req_meta_update = self.update_decoder_info(
-                            req_id, req_meta)
-                        logger.debug(
-                            f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
-                        )
+                        req_meta_update = self.update_decoder_info(req_id, req_meta)
+                        logger.debug(f"Add request {req_id} to kv send layer thread. {req_meta_update=}")
                         send_task.send_request[req_id] = req_meta_update
                 self.kv_send_layer_thread.send_queue.put(send_task)
         self.current_layer += 1

-    def _get_remote_socket(
-            self, remote_host: str,
-            remote_handshake_port: int) -> zmq.Socket:  # type: ignore
+    def _get_remote_socket(self, remote_host: str, remote_handshake_port: int) -> zmq.Socket:  # type: ignore
         """Get a socket to the remote host."""
         remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port)
         with self.remote_sockets_lock:
@@ -1181,10 +1149,12 @@ class MooncakeLayerwiseConnectorWorker:
                 ctx=ctx,
                 path=remote_path,
                 socket_type=zmq.REQ,  # type: ignore
-                bind=False)
+                bind=False,
+            )
             sock.setsockopt(
                 zmq.SNDTIMEO,  # type: ignore
-                int(self.timeout * 1000))
+                int(self.timeout * 1000),
+            )
             self.remote_poller.register(sock, zmq.POLLIN)  # type: ignore
             return sock
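Review note: the `selected_req_ids` comprehension above shards requests round-robin across replica groups, so each request's KV is sent exactly once even though several prefill TP ranks hold the same shard. A standalone sketch of that selection, with a simplified `num_replica_groups` derivation (the real code computes `num_need_send` differently for MLA/sparse vs. dense layouts):

```python
# Sketch: round-robin assignment of requests to replica groups of TP ranks.
def shard_requests(req_ids: list[str], tp_rank: int, tp_size: int, num_need_send: int) -> list[str]:
    # Assumption for illustration: ranks holding identical KV form
    # tp_size // num_need_send replica groups.
    num_replica_groups = max(tp_size // num_need_send, 1)
    replica_group_idx = tp_rank % num_replica_groups
    req_ids = sorted(req_ids)
    # Each group takes every num_replica_groups-th request.
    return [req_id for i, req_id in enumerate(req_ids) if i % num_replica_groups == replica_group_idx]


# 4 prefill TP ranks feeding 2 decode ranks -> 2 replica groups:
reqs = ["req-a", "req-b", "req-c", "req-d"]
assert shard_requests(reqs, tp_rank=0, tp_size=4, num_need_send=2) == ["req-a", "req-c"]
assert shard_requests(reqs, tp_rank=1, tp_size=4, num_need_send=2) == ["req-b", "req-d"]
```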
@@ -1192,56 +1162,62 @@ class MooncakeLayerwiseConnectorWorker:
         req_meta_update = copy.deepcopy(req_meta)
         if self.use_mla or self.use_sparse:
             pd_tp_ratio = self.tp_size // self._decode_tp_size
-            req_meta_update.remote_port = req_meta_update.remote_port + (
-                self.tp_rank // pd_tp_ratio) % self._decode_tp_size
+            req_meta_update.remote_port = (
+                req_meta_update.remote_port + (self.tp_rank // pd_tp_ratio) % self._decode_tp_size
+            )
         else:
-            req_meta_update.remote_port = req_meta_update.remote_port + (
-                self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size
-        if req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr or \
-            req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id]:
+            req_meta_update.remote_port = (
+                req_meta_update.remote_port + (self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size
+            )
+        if (
+            req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr
+            or req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id]
+        ):
             try:
                 encoded_data = self.encoder.encode((GET_META_MSG, req_id))
-                sock = self._get_remote_socket(req_meta_update.remote_host,
-                                               req_meta_update.remote_port)
+                sock = self._get_remote_socket(req_meta_update.remote_host, req_meta_update.remote_port)
                 ensure_zmq_send(sock, encoded_data)
                 metadata_bytes = ensure_zmq_recv(sock, self.remote_poller)
                 agent_meta = self.decoder.decode(metadata_bytes)
             except Exception as e:
                 logger.error(
-                    f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} fail with error: {e}"
+                    f"Query to port and kv base addr for request {req_id} from "
+                    f"{req_meta_update.remote_host}:{req_meta_update.remote_port} fail with error: {e}"
                 )
             assert req_meta_update.remote_engine_id != self.engine_id, (
-                f"Conflict engine id {req_meta_update.remote_engine_id} with local engine id "
-                f"{self.local_engine_id}.")
+                f"Conflict engine id {req_meta_update.remote_engine_id} with local engine id {self.local_engine_id}."
+            )
-            self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id][
-                req_meta_update.remote_port] = agent_meta.kv_caches_base_addr
-            self.remote_te_port[req_meta_update.remote_engine_id][
-                req_meta_update.remote_port] = agent_meta.te_rpc_port
+            self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id][req_meta_update.remote_port] = (
+                agent_meta.kv_caches_base_addr
+            )
+            self.remote_te_port[req_meta_update.remote_engine_id][req_meta_update.remote_port] = agent_meta.te_rpc_port
             logger.info(
-                f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}"
+                f"Query to port and kv base addr for request {req_id} from "
+                f"{req_meta_update.remote_host}:{req_meta_update.remote_port} success "
+                f"{agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}"
             )
             if self.pd_head_ratio > 1:
                 # for tp inequal, pre-create link to prevent alltoall out of memory
                 session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}"
                 ret = self.engine.batch_transfer_sync_write(
-                    session_id, [self.kv_caches_base_addr[0]],
-                    [agent_meta.kv_caches_base_addr[0]], [128])
+                    session_id, [self.kv_caches_base_addr[0]], [agent_meta.kv_caches_base_addr[0]], [128]
+                )
                 if ret < 0:
-                    logger.error(
-                        f"Mooncake transfer failed to create link to device {session_id}"
-                    )
+                    logger.error(f"Mooncake transfer failed to create link to device {session_id}")
-        req_meta_update.remote_te_rpc_port = self.remote_te_port[
-            req_meta_update.remote_engine_id][req_meta_update.remote_port]
-        req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[
-            req_meta_update.remote_engine_id][req_meta_update.remote_port]
+        req_meta_update.remote_te_rpc_port = self.remote_te_port[req_meta_update.remote_engine_id][
+            req_meta_update.remote_port
+        ]
+        req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id][
+            req_meta_update.remote_port
+        ]
         return req_meta_update

     def send_done_send_signal(self, req_id, req_meta):
-        logger.info("Sending done sending signal for request %s to %s:%d",
-                    req_id, req_meta.remote_host, req_meta.remote_port)
+        logger.info(
+            "Sending done sending signal for request %s to %s:%d", req_id, req_meta.remote_host, req_meta.remote_port
+        )
         try:
-            path = make_zmq_path("tcp", req_meta.remote_host,
-                                 req_meta.remote_port)
+            path = make_zmq_path("tcp", req_meta.remote_host, req_meta.remote_port)
             msg_encoder = msgspec.msgpack.Encoder()
             encoded_data = msg_encoder.encode((DONE_SENDING_MSG, req_id))
             with zmq_ctx(zmq.REQ, path) as sock:  # type: ignore
@@ -1251,7 +1227,8 @@ class MooncakeLayerwiseConnectorWorker:
                 raise ValueError(f"Unexpected ACK response: {ack}")
         except Exception as e:
             logger.error(
-                f"Sending done sending signal for request {req_id} to {req_meta.remote_host}:{req_meta.remote_port} fail with error: {e}"
+                f"Sending done sending signal for request {req_id} to "
+                f"{req_meta.remote_host}:{req_meta.remote_port} fail with error: {e}"
             )

     def wait_for_layer_load(self, layer_name: str) -> None:
@@ -1259,37 +1236,34 @@ class MooncakeLayerwiseConnectorWorker:


 @contextlib.contextmanager
-def zmq_ctx(socket_type: Any,
-            addr: str) -> Iterator[zmq.Socket]:  # type: ignore
+def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:  # type: ignore
     """Context manager for a ZMQ socket"""
     if socket_type not in (zmq.ROUTER, zmq.REQ, zmq.DEALER):  # type: ignore
         raise ValueError(f"Unexpected socket type: {socket_type}")

-    ctx: Optional[zmq.Context] = None  # type: ignore
+    ctx: zmq.Context | None = None  # type: ignore
     try:
         ctx = zmq.Context()  # type: ignore
-        yield make_zmq_socket(ctx=ctx,
-                              path=addr,
-                              socket_type=socket_type,
-                              bind=socket_type == zmq.ROUTER)  # type: ignore
+        yield make_zmq_socket(ctx=ctx, path=addr, socket_type=socket_type, bind=socket_type == zmq.ROUTER)  # type: ignore
     finally:
         if ctx is not None:
             ctx.destroy(linger=0)


 def group_concurrent_contiguous(
-    src: List[int],
-    dst: List[int] = []
-) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
+    src: list[int], dst: list[int] | None = None
+) -> tuple[list[npt.NDArray[np.int64]], list[npt.NDArray[np.int64]]]:
     """Vectorised NumPy implementation."""
+    if dst is None:
+        dst = []
     if not dst:
         src_only_indices: npt.NDArray[np.int64] = np.array(src, dtype=np.int64)
         if src_only_indices.size == 0:
             return [], []

-        brk = np.where((np.diff(src_only_indices) != 1))[0] + 1
+        brk = np.where(np.diff(src_only_indices) != 1)[0] + 1
         src_groups = np.split(src_only_indices, brk)

         src_groups = [g.tolist() for g in src_groups]
@@ -1302,8 +1276,7 @@ def group_concurrent_contiguous(
     if src_indices.size == 0:
         return [], []

-    brk = np.where((np.diff(src_indices) != 1)
-                   | (np.diff(dst_indices) != 1))[0] + 1
+    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
     src_groups = np.split(src_indices, brk)
     dst_groups = np.split(dst_indices, brk)

@@ -1324,9 +1297,10 @@ def string_to_int64_hash(input_str):


 def ensure_zmq_send(
-        socket: zmq.Socket,  # type: ignore
-        data: bytes,
-        max_retries: int = 3):
+    socket: zmq.Socket,  # type: ignore
+    data: bytes,
+    max_retries: int = 3,
+):
     retries_left = max_retries
     while True:
         try:
@@ -1335,21 +1309,19 @@ def ensure_zmq_send(
         except zmq.ZMQError as e:  # type: ignore
             retries_left -= 1
             if retries_left > 0:
-                logger.warning(
-                    f"Send failed: {e}, retrying... ({retries_left} "
-                    "attempts left)")
+                logger.warning(f"Send failed: {e}, retrying... ({retries_left} attempts left)")
                 time.sleep(0.1)
             else:
                 logger.error(f"Send failed after all retries: {e}")
-                raise RuntimeError(f"Failed to send data after {max_retries} "
-                                   f"retries: {e}")
+                raise RuntimeError(f"Failed to send data after {max_retries} retries: {e}")


 def ensure_zmq_recv(
-        socket: zmq.Socket,  # type: ignore
-        poller: zmq.Poller,  # type: ignore
-        timeout: float = 1.0,
-        max_retries: int = 3) -> bytes:
+    socket: zmq.Socket,  # type: ignore
+    poller: zmq.Poller,  # type: ignore
+    timeout: float = 1.0,
+    max_retries: int = 3,
+) -> bytes:
     retries_left = max_retries
     while True:
         try:
@@ -1361,11 +1333,8 @@ def ensure_zmq_recv(
         except zmq.ZMQError as e:  # type: ignore
             retries_left -= 1
             if retries_left > 0:
-                logger.warning(f"Receive failed: {e}, retrying... "
-                               f"({retries_left} attempts left)")
+                logger.warning(f"Receive failed: {e}, retrying... ({retries_left} attempts left)")
                 time.sleep(0.1)
             else:
                 logger.error(f"Receive failed after all retries: {e}")
-                raise RuntimeError(
-                    f"Failed to receive data after {max_retries} "
-                    f"retries: {e}")
+                raise RuntimeError(f"Failed to receive data after {max_retries} retries: {e}")
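Review note: `group_concurrent_contiguous` above merges runs of adjacent block ids, presumably so each contiguous range can be issued as a single transfer descriptor instead of one per block. A quick demonstration of the `np.diff` / `np.split` idiom it relies on:

```python
# Break points are the indices where consecutive ids stop being contiguous;
# splitting there yields maximal contiguous runs.
import numpy as np

src = np.array([3, 4, 5, 9, 10, 42], dtype=np.int64)
brk = np.where(np.diff(src) != 1)[0] + 1  # -> [3, 5]
groups = [g.tolist() for g in np.split(src, brk)]
assert groups == [[3, 4, 5], [9, 10], [42]]

# With paired src/dst lists, a run must be contiguous on BOTH sides:
dst = np.array([7, 8, 9, 20, 21, 50], dtype=np.int64)
brk2 = np.where((np.diff(src) != 1) | (np.diff(dst) != 1))[0] + 1
assert [g.tolist() for g in np.split(src, brk2)] == [[3, 4, 5], [9, 10], [42]]
```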