### What this PR does / why we need it?
PD (prefill/decode) disaggregation now supports cross-machine deployments.
The primary and secondary node information of the prefill (P) node is sent to the decode (D) node.
When the D node pulls the KV data, it looks up the corresponding primary or
secondary node information in that mapping.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
1547 lines
68 KiB
Python
1547 lines
68 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
import contextlib
|
|
import hashlib
|
|
import math
|
|
import os
|
|
import queue
|
|
import random
|
|
import struct
|
|
import threading
|
|
import time
|
|
from collections import defaultdict, deque
|
|
from collections.abc import Iterator
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Any, List, Optional, OrderedDict, Tuple
|
|
|
|
import msgspec
|
|
import numpy as np
|
|
import numpy.typing as npt
|
|
import torch
|
|
import torch_npu
|
|
import zmq
|
|
from mooncake.engine import TransferEngine # type: ignore
|
|
from vllm import envs
|
|
from vllm.config import VllmConfig
|
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
|
KVConnectorBase_V1, KVConnectorHandshakeMetadata, KVConnectorMetadata,
|
|
KVConnectorRole)
|
|
from vllm.distributed.parallel_state import (
|
|
get_decode_context_model_parallel_rank,
|
|
get_decode_context_model_parallel_world_size, get_pp_group,
|
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
|
|
get_tp_group)
|
|
from vllm.distributed.utils import get_pp_indices
|
|
from vllm.logger import logger
|
|
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
|
|
from vllm.v1.core.sched.output import SchedulerOutput
|
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
|
from vllm.v1.request import RequestStatus
|
|
|
|
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
|
|
from vllm_ascend.distributed.mooncake_transfer_engine import global_te
|
|
from vllm_ascend.distributed.utils import get_transfer_timeout_value
|
|
from vllm_ascend.utils import prefill_context_parallel_enable
|
|
|
|
# isort: off
|
|
if prefill_context_parallel_enable():
|
|
from vllm.distributed import (get_prefill_context_model_parallel_rank,
|
|
get_prefill_context_model_parallel_world_size
|
|
)
|
|
# isort: on
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.attention.backends.abstract import AttentionMetadata
|
|
from vllm.forward_context import ForwardContext
|
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
|
from vllm.v1.request import Request
|
|
|
|
GET_META_MSG = b"get_meta_msg"
|
|
DONE_RECVING_MSG = b"done_recving_msg"
|
|
|
|
|
|
class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True):
    """Handshake payload describing one worker's KV cache registration.

    The sending thread msgpack-encodes an instance of this struct and returns
    it in reply to ``GET_META_MSG``; the receiving thread decodes it and reads
    ``engine_id``, ``kv_caches_base_addr`` and ``te_rpc_port`` from it.
    """

    # Identifier of the engine that produced this metadata.
    engine_id: str
    # Port combined with the remote host to form the transfer session id.
    te_rpc_port: int
    # Base addresses of the KV cache buffers (presumably one per registered
    # cache tensor -- confirm against the registration code).
    kv_caches_base_addr: list[int]
    # Number of KV cache blocks (not read anywhere in the visible code).
    num_blocks: int
    # IP of the worker; defaults to an empty string.
    local_ip: str = ""
|
|
|
|
|
|
@dataclass
class ReqMeta:
    """Per-request description of where to pull KV blocks from."""

    # Block ids on the local side that receive the data.
    local_block_ids: list[int]
    # Block ids on the remote side that hold the data.
    remote_block_ids: list[int]
    # Host of the remote side channel.
    remote_host: str
    # Port of the remote side channel.
    remote_port: int
    # Engine id of the remote instance.
    remote_engine_id: str
    # Remote prefill-context-parallel size (callers default this to 1).
    remote_pcp_size: int
    # Remote decode-context-parallel size (callers default this to 1).
    remote_dcp_size: int
    # Cross-node metadata mapping supplied by the remote side (callers
    # default this to an empty dict).
    remote_multi_nodes_meta_mapping: dict[str, dict[str, Any]]
|
|
|
|
|
|
class SizedDict(OrderedDict):
    """Insertion-ordered mapping bounded to ``max_size`` entries.

    * Storing a key that pushes the size past ``max_size`` evicts the oldest
      entry.
    * Subscripting a missing key does not raise: an empty ``dict`` is stored
      under that key (which may itself evict the oldest entry) and returned,
      so nested writes such as ``d[a][b] = v`` need no pre-initialisation.

    NOTE: this class is deliberately *not* a ``@dataclass``. It declares no
    fields, so the generated ``__eq__`` would treat every instance as equal
    and the generated ``__repr__`` would hide the stored items.
    """

    def __init__(self, max_size=16000, *args, **kwargs):
        """Create the mapping.

        Args:
            max_size: Maximum number of entries kept before eviction.
            *args, **kwargs: Forwarded to ``OrderedDict`` as initial content.
        """
        # The bound must exist before the base initialiser runs, because
        # initial items may be routed through ``__setitem__``.
        self.max_size = max_size
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        """Store ``value`` and drop the oldest entry if over capacity."""
        super().__setitem__(key, value)
        if len(self) > self.max_size:
            self.popitem(last=False)

    def __getitem__(self, key):
        """Return the value for ``key``, creating an empty dict if absent."""
        try:
            return super().__getitem__(key)
        except KeyError:
            value: dict[int, list[int]] = {}
            self[key] = value
            return value
|
|
|
|
|
|
class KVCacheTaskTracker:
    """Thread-safe bookkeeping of finished and delayed-free requests.

    All public methods take ``done_task_lock``; the underscore-prefixed
    helpers expect the caller to already hold it.
    """

    def __init__(self):
        super().__init__()

        self.done_task_lock = threading.Lock()
        # Requests reported as done since the last drain.
        self.finished_requests: set[str] = set()
        # Requests that completed before a delayed-free entry was registered
        # for them; consulted by add_delayed_request().
        self.record_finished_requests: set[str] = set()
        # Only used in prefill node. Maps request_id -> timestamp for requests
        # whose kv blocks freeing is intentionally delayed. If a request
        # remains here for too long, it will be force-freed.
        self.delayed_free_requests: OrderedDict[str, float] = OrderedDict()

    def add_not_transfer_request(self, request_id: str):
        """Mark ``request_id`` as finished without any other bookkeeping."""
        with self.done_task_lock:
            self.finished_requests.add(request_id)

    def update_done_task_count(self, request_id: str):
        """Mark ``request_id`` as finished and reconcile delayed-free state."""
        with self.done_task_lock:
            self.finished_requests.add(request_id)
            if request_id not in self.delayed_free_requests:
                # Completion arrived first; remember it for
                # add_delayed_request().
                self.record_finished_requests.add(request_id)
            else:
                self._remove_delayed_requests(request_id)

    def get_and_clear_finished_requests(self) -> set[str]:
        """
        Drain the finished requests.

        Returns:
            The request IDs finished since the previous call, plus any
            delayed-free requests that have timed out.
        """
        with self.done_task_lock:
            drained = set(self.finished_requests)
            drained |= self._retrieve_expired_requests()
            self.finished_requests.clear()
        return drained

    def add_delayed_request(self, request_id: str, delay_start_time: float):
        """Add a delayed free request."""
        with self.done_task_lock:
            if request_id in self.record_finished_requests:
                # Already completed: nothing to delay, just forget the record.
                self.record_finished_requests.discard(request_id)
            else:
                self.delayed_free_requests[request_id] = delay_start_time

    def _retrieve_expired_requests(self):
        """Retrieve all expired delayed requests."""
        timed_out: set[str] = set()
        now = time.time()
        # Entries are examined oldest-inserted first; stop at the first one
        # that has not exceeded the timeout.
        while self.delayed_free_requests:
            oldest_id = next(iter(self.delayed_free_requests))
            started_at = self.delayed_free_requests[oldest_id]
            is_expired = (now - started_at
                          > envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
            if not is_expired:
                break
            self.delayed_free_requests.popitem(last=False)
            timed_out.add(oldest_id)
            logger.info("Force freed request: %s", oldest_id)
        return timed_out

    def _remove_delayed_requests(self, request_id: str):
        """Remove the delayed free entry for the given request_id."""
        self.delayed_free_requests.pop(request_id)
|
|
|
|
|
|
class KVCacheSendingThread(threading.Thread):
    """Daemon thread that serves the ZMQ side channel of a sending worker.

    It replies to ``GET_META_MSG`` with the encoded agent metadata and records
    ``DONE_RECVING_MSG`` notifications in a ``KVCacheTaskTracker``.
    """

    def __init__(self, vllm_config: VllmConfig, tp_rank: int,
                 prefill_tp_size: int, local_engine_id: str,
                 side_channel_host: str, side_channel_port: int,
                 metadata: MooncakeAgentMetadata, ready_event: threading.Event,
                 kv_caches: dict[str, Any], pcp_rank: int):
        """Store the configuration; the socket is opened later in ``run``."""
        super().__init__(daemon=True, name="KVCacheSendingThread")

        # Parallel layout of this worker.
        self.tp_rank = tp_rank
        self.prefill_tp_size = prefill_tp_size
        self.pp_rank = get_pp_group().rank_in_group
        self.pp_size = vllm_config.parallel_config.pipeline_parallel_size
        self.tp_size = get_tensor_model_parallel_world_size()

        # Side-channel endpoint and the payload served on it.
        self.local_engine_id = local_engine_id
        self.side_channel_host = side_channel_host
        self.side_channel_port = side_channel_port
        self.metadata = metadata
        self.ready_event = ready_event
        self.kv_caches = kv_caches
        self.pcp_rank = pcp_rank

        self.task_tracker = KVCacheTaskTracker()

    def get_and_clear_finished_requests(self) -> set[str]:
        """
        Drain the completed requests from the task tracker.

        Returns:
            A set of request IDs that have been completed.
        """
        return self.task_tracker.get_and_clear_finished_requests()

    def add_not_transfer_request(self, request_id: str):
        """Forward to ``KVCacheTaskTracker.add_not_transfer_request``."""
        self.task_tracker.add_not_transfer_request(request_id)

    def add_delayed_request(self, request_id: str, delay_start_time: float):
        """Forward to ``KVCacheTaskTracker.add_delayed_request``."""
        return self.task_tracker.add_delayed_request(request_id,
                                                     delay_start_time)

    def run(self):
        """Run the thread to handle KV cache transfer requests."""
        try:
            # NOTE(rob): each rank needs a unique port, so the port is the
            # base side-channel port plus a per-rank offset. This is a
            # stop-gap until etcd or a single scheduler-side ZMQ socket is
            # used.
            rank_offset = (self.pp_rank * self.tp_size + self.tp_rank +
                           self.pcp_rank * self.prefill_tp_size)
            listen_port = self.side_channel_port + rank_offset
            listen_path = make_zmq_path("tcp", self.side_channel_host,
                                        listen_port)
            logger.info("Starting listening on path: %s", listen_path)
            with zmq_ctx(zmq.ROUTER, listen_path) as router:  # type: ignore
                # Signal readiness only once the socket exists.
                self.ready_event.set()
                self.run_busy_loop(router)
        except Exception as err:
            logger.error("Mooncake KVCacheSendingThread exception: %s",
                         err,
                         exc_info=True)

    def run_busy_loop(self, sock: zmq.Socket):  # type: ignore
        """Serve side-channel messages on ``sock`` forever.

        Any exception raised while handling one message is logged and the
        loop moves on to the next message.
        """
        # The metadata is encoded once and reused for every GET_META_MSG.
        encoded_meta = msgspec.msgpack.Encoder().encode(self.metadata)
        logger.debug("Size of encoded MooncakeAgentMetadata: %s bytes",
                     str(len(encoded_meta)))

        decoder = msgspec.msgpack.Decoder(type=tuple)
        while True:
            try:
                self._dispatch(sock, sock.recv_multipart(), decoder,
                               encoded_meta)
            except Exception as err:
                logger.error("Connection listener got exception %s: %s",
                             type(err), err)

    def _dispatch(self, sock, frames, decoder, encoded_meta):
        """Validate one multipart message and act on its tag."""
        if len(frames) < 2:
            logger.error("Invalid message format: %s", frames)
            return

        identity, *rest = frames
        # Drop empty delimiter frames; exactly one payload frame must remain.
        payload = [frame for frame in rest if frame != b""]
        if len(payload) != 1:
            logger.error("Invalid message format: %s", frames)
            return

        msg = decoder.decode(payload[0])
        if msg[0] == GET_META_MSG:
            sock.send_multipart((identity, b"", encoded_meta))
        elif msg[0] == DONE_RECVING_MSG:
            logger.debug("Got DONE_RECVING_MSG for request %s", msg[1])
            self.task_tracker.update_done_task_count(msg[1])
            # Acknowledge the request completion.
            self._send_ack(sock, identity, msg[1])
        else:
            logger.error("Connection listener got unexpected message %s",
                         msg)

    def _send_ack(self, sock, identity, request_id):
        """Send ``ACK`` to ``identity``, retrying while the socket is busy."""
        while True:
            try:
                sock.send_multipart(
                    (identity, b"", b"ACK"),
                    flags=zmq.NOBLOCK)  # type: ignore
                return
            except zmq.Again:  # type: ignore
                # If the socket is not ready, retry sending.
                logger.debug(
                    "Socket not ready, retrying to send ACK for "
                    "request %s", request_id)
                time.sleep(0.01)
|
|
|
|
|
|
class KVCacheRecvingThread(threading.Thread):
    """Daemon thread that pulls KV cache blocks from remote workers.

    Requests are queued with ``add_request`` and processed one at a time in
    ``run``. For each request the thread fetches (and caches) the remote
    worker's metadata over a ZMQ REQ socket, issues a batched read through the
    Mooncake transfer engine, and finally notifies the remote side with
    ``DONE_RECVING_MSG``.
    """

    def __init__(self, tp_rank: int, tp_size: int, _prefill_pp_size: int,
                 engine: TransferEngine, local_engine_id: str,
                 local_handshake_port: int,
                 local_kv_caches_base_addr: list[int], block_len: list[int],
                 ready_event: threading.Event, vllm_config: VllmConfig,
                 kv_caches: dict[str, Any]):
        """Store configuration and create the queue, tracker and ZMQ helpers.

        Args:
            tp_rank: Tensor-parallel rank of this worker (used for logging).
            tp_size: Tensor-parallel size used to derive KV heads per rank.
            _prefill_pp_size: Pipeline-parallel size of the remote side.
            engine: Transfer engine used for ``batch_transfer_sync_read``.
            local_engine_id: Engine id of this instance.
            local_handshake_port: Key under which the local base addresses
                are stored.
            local_kv_caches_base_addr: Base addresses of the local KV caches.
            block_len: Per-cache block lengths; its length selects the
                MLA (2) / sparse (3) layouts.
            ready_event: Set as soon as ``run`` starts.
            vllm_config: Global vLLM configuration.
            kv_caches: Mapping of layer name to that layer's cache tensors.
        """
        super().__init__(daemon=True, name="KVCacheRecvingThread")
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self._prefill_pp_size = _prefill_pp_size
        self.local_engine_id = local_engine_id
        self.local_handshake_port = local_handshake_port
        self.engine = engine
        self.ready_event = ready_event

        self.kv_caches = kv_caches
        # engine_id -> handshake port -> KV cache base addresses. The local
        # engine's addresses live in the same table as the remote ones.
        self.kv_caches_base_addr: dict[str, dict[int, list[int]]] = \
            SizedDict()
        self.kv_caches_base_addr[local_engine_id][local_handshake_port] = \
            local_kv_caches_base_addr
        # engine_id -> handshake port -> transfer-engine RPC port.
        self.remote_te_port: dict[str, dict[int, int]] = \
            SizedDict()
        self.block_len = block_len
        # TODO(jianzs): find a better way to detect MLA.
        self.use_mla = len(block_len) == 2
        self.use_sparse = len(block_len) == 3

        self.request_queue: queue.Queue[Any] = queue.Queue()
        self.executor = ThreadPoolExecutor(max_workers=32)

        self.task_tracker = KVCacheTaskTracker()

        self.encoder = msgspec.msgpack.Encoder()
        self.decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
        self.remote_sockets_lock = threading.Lock()
        # Pool of idle REQ sockets keyed by remote ZMQ path.
        self.remote_sockets: dict[  # type: ignore
            str, deque[zmq.Socket]] = defaultdict(  # type: ignore
                deque)
        self.remote_poller = zmq.Poller()  # type: ignore
        self.timeout = 1.0  # seconds

        self.vllm_config = vllm_config
        self.model_config = self.vllm_config.model_config
        self.block_size = self.vllm_config.cache_config.block_size
        if self.use_mla:
            self.k_head_dim = self.model_config.hf_config.kv_lora_rank
            self.v_head_dim = self.model_config.hf_config.qk_rope_head_dim
            self.num_kv_heads = 1
        else:
            self.k_head_dim = self.model_config.hf_config.head_dim
            self.v_head_dim = self.model_config.hf_config.head_dim
            self.num_kv_heads = max(
                self.model_config.hf_config.num_key_value_heads //
                self.tp_size, 1)

    def add_request(self, request_id: str, local_block_ids: list[int],
                    remote_block_ids: list[int], remote_engine_id: str,
                    remote_host: str, remote_handshake_port: int, offset: int,
                    tp_num_need_pulls: int, all_task_done: bool):
        """Add a new request to the queue for processing."""
        logger.debug(f"Adding request {request_id} to the queue.")
        self.request_queue.put({
            "request_id": request_id,
            "local_block_ids": local_block_ids,
            "remote_block_ids": remote_block_ids,
            "remote_engine_id": remote_engine_id,
            "remote_host": remote_host,
            "remote_handshake_port": remote_handshake_port,
            "offset": offset,
            "tp_num_need_pulls": tp_num_need_pulls,
            "all_task_done": all_task_done
        })

    def get_and_clear_finished_requests(self) -> set[str]:
        """
        Get and clear the requests that have been completed.

        Returns:
            A set of request IDs that have been completed.
        """
        return self.task_tracker.get_and_clear_finished_requests()

    def run(self):
        """Run the thread to handle KV cache transfer requests."""
        self.ready_event.set()
        while True:
            try:
                request_data = self.request_queue.get()
                if request_data is None:
                    logger.warning("Received a None request!")
                    self.request_queue.task_done()
                    continue
                self._handle_request(request_data)
            except Exception as e:
                logger.error(f"Error in KVCacheTransferThread: {e}")

    def _handle_request(self, req_meta: dict[str, Any]):
        """Transfer one request's KV cache and always finish its bookkeeping.

        A transfer failure is logged and swallowed. A failure while notifying
        the remote host still propagates to the caller, but only after the
        request has been marked done (when ``all_task_done``) and the queue
        item has been acknowledged, so neither the local request nor
        ``request_queue.join()`` can hang on a failed notification.
        """
        request_id = req_meta["request_id"]
        remote_host = req_meta["remote_host"]
        remote_handshake_port = req_meta["remote_handshake_port"]
        all_task_done = req_meta["all_task_done"]

        try:
            logger.debug(
                f"Starting to transfer KV cache for request {request_id}.")
            self._transfer_kv_cache(req_meta)
            logger.debug(
                f"Finished transferring KV cache for request {request_id}.")
        except Exception as e:
            logger.error("Failed to transfer KV cache for request "
                         f"{request_id}: {e}")
        finally:
            try:
                # Always send the done signal to the remote host to ensure
                # proper resource cleanup. Failing to do so may cause a
                # memory leak on the remote host.
                self._send_done_recv_signal(request_id, remote_host,
                                            remote_handshake_port)
            finally:
                # Local bookkeeping must run even if the notification above
                # raised.
                if all_task_done:
                    self.task_tracker.update_done_task_count(request_id)
                self.request_queue.task_done()

    def _transfer_kv_cache(self, req_meta: dict[str, Any]):
        """Handle a KV cache transfer request.

        Raises:
            RuntimeError: If the transfer engine reports a negative result.
        """
        request_id = req_meta["request_id"]
        remote_block_ids = req_meta["remote_block_ids"]
        local_block_ids = req_meta["local_block_ids"]
        remote_engine_id = req_meta["remote_engine_id"]
        remote_host = req_meta["remote_host"]
        remote_handshake_port = req_meta["remote_handshake_port"]
        offset = req_meta["offset"]
        tp_num_need_pulls = req_meta["tp_num_need_pulls"]

        # Full prefix cache hit: do not need to read remote blocks, just notify
        # P worker that we have the blocks we need.
        num_local_blocks = len(local_block_ids)
        if num_local_blocks == 0:
            return

        num_remote_blocks = len(remote_block_ids)
        assert num_local_blocks <= num_remote_blocks
        if num_local_blocks < num_remote_blocks:
            # Only the trailing remote blocks are pulled.
            remote_block_ids = remote_block_ids[-num_local_blocks:]

        # Check if we have the remote metadata cached.
        if remote_engine_id not in self.kv_caches_base_addr or \
            remote_handshake_port not in self.kv_caches_base_addr[remote_engine_id]:
            self._get_remote_metadata(remote_host, remote_handshake_port)

        if tp_num_need_pulls == 1:
            grouped_remote_block_ids, grouped_local_block_ids = \
                group_concurrent_contiguous(remote_block_ids, local_block_ids)
        else:
            # One group per block: every block is transferred on its own.
            remote_block_ids = [[block_id] for block_id in remote_block_ids]
            local_block_ids = [[block_id] for block_id in local_block_ids]
            grouped_remote_block_ids, grouped_local_block_ids = remote_block_ids, local_block_ids
        num_transfer_groups = len(grouped_remote_block_ids)
        # tp_num_need_pulls: number of KV caches each Decode node needs to pull from each PP stage
        # Due to GQA, different KV heads are distributed across different ranks, so there are offsets
        # indicating which KV head to pull
        global_offset = offset  # Global offset of request across all ranks
        prefill_pp_rank = offset // tp_num_need_pulls  # PP rank where current request resides
        inner_offset = offset % tp_num_need_pulls  # Offset within each PP stage

        remote_kv_caches_base_addrs = \
            self.kv_caches_base_addr[remote_engine_id][remote_handshake_port]
        num_layers = self.model_config.hf_config.num_hidden_layers
        first_layer_index, end_layer_index = get_pp_indices(
            num_layers, prefill_pp_rank, self._prefill_pp_size)
        num_cache_per_layer = len(list(
            self.kv_caches.values())[0])  # Number of KV caches per layer
        # Restrict the local addresses to the layers owned by that PP stage.
        local_kv_caches_base_addrs = \
            self.kv_caches_base_addr[self.local_engine_id][self.local_handshake_port][first_layer_index*num_cache_per_layer : end_layer_index*num_cache_per_layer]
        logger.debug(
            f"transfer kv cache first_layer_index:{first_layer_index} , end_layer_index:{end_layer_index}"
        )
        remote_transfer_port = self.remote_te_port[remote_engine_id][
            remote_handshake_port]
        num_blocks = len(local_block_ids)
        session_id = f"{remote_host}:{remote_transfer_port}"

        req_start_time = time.perf_counter()
        src_list, dst_list, length_list = [], [], []
        for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
                zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
            # Pick the block length matching this cache's position in the
            # layer.
            if self.use_mla:
                block_len = (self.block_len[k % 2])
            elif self.use_sparse:
                block_len = (self.block_len[k % 3])
            else:
                block_len = (self.block_len[0])
            inner_block_len = block_len // tp_num_need_pulls
            for remote_block_id, local_block_id in zip(
                    grouped_remote_block_ids, grouped_local_block_ids):
                src = src_layer_base_addr + local_block_id[
                    0] * block_len + inner_offset * inner_block_len
                dst = dst_layer_base_addr + remote_block_id[0] * inner_block_len
                length = inner_block_len * len(local_block_id)
                src_list.append(src)
                dst_list.append(dst)
                length_list.append(length)

        # All layers and groups go out in a single batched read.
        ret = self.engine.batch_transfer_sync_read(session_id, src_list,
                                                   dst_list, length_list)
        if ret < 0:
            logger.error("Mooncake transfer failed for request %s",
                         req_meta["request_id"])
            raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")

        req_end_time = time.perf_counter()
        req_transfer_elapsed = (req_end_time - req_start_time) * 1000
        logger.info(
            "KV cache transfer for request %s took %.2f ms (%d groups,"
            " %d blocks). local_ip %s local_device_id %s remote_session_id %s",
            request_id, req_transfer_elapsed, num_transfer_groups, num_blocks,
            get_ip(), self.tp_rank, session_id)

        # Determine if the current position is the offset position at the end of the KV transmission.
        is_kv_transfer_end = (
            global_offset == tp_num_need_pulls * self._prefill_pp_size - 1)
        need_cat_cache = tp_num_need_pulls > 1 and is_kv_transfer_end
        # need_nz_cache maybe caused error in non-MLA models
        if need_cat_cache:
            self._cat_kv_cache(grouped_local_block_ids, tp_num_need_pulls)

    def _cat_kv_cache(self, block_ids: list[list[int]],
                      tp_num_need_pulls: int):
        """Re-arrange the pulled blocks of every layer in place.

        For each layer the blocks are loaded into scratch buffers, transposed
        by ``_transpose_kv_cache_between_head`` and written back to the same
        slots.
        """
        # Get necessary parameters
        k_cache = list(self.kv_caches.values())[0][0]
        dtype = k_cache.dtype
        device = k_cache.device
        head_dim = self.model_config.hf_config.head_dim
        block_size = self.vllm_config.cache_config.block_size
        num_kv_head = max(
            self.model_config.hf_config.num_key_value_heads // self.tp_size, 1)

        flat_block_ids = [item for sublist in block_ids for item in sublist]
        block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32)
        num_blocks = len(flat_block_ids)
        block_len = num_blocks * block_size

        # Create device tensors for copy operations
        block_table = block_ids_tensor.view(1, -1).to(device=device)
        block_len_tensor = torch.tensor([block_len],
                                        dtype=torch.int32).to(device=device)
        seq_start_tensor = torch.tensor([0],
                                        dtype=torch.int32).to(device=device)

        # Initialize buffers
        k_buffer = torch.empty(block_len,
                               num_kv_head,
                               head_dim,
                               dtype=dtype,
                               device=device)
        v_buffer = torch.empty(block_len,
                               num_kv_head,
                               head_dim,
                               dtype=dtype,
                               device=device)

        # Create slot mapping for reshape operations
        block_offsets = torch.arange(0, block_size, dtype=torch.int32)
        slot_mapping = (block_offsets.reshape(
            (1, block_size)) + block_ids_tensor.reshape(
                (num_blocks, 1)) * block_size)
        slot_mapping = slot_mapping.flatten().to(device=device)

        # Process each layer in the KV cache
        for k_cache_layer, v_cache_layer in self.kv_caches.values():
            # Load cache data into buffers
            torch_npu.atb.npu_paged_cache_load(
                k_cache_layer,
                v_cache_layer,
                block_table,
                block_len_tensor,
                seq_starts=seq_start_tensor,
                key=k_buffer,
                value=v_buffer,
            )

            # Transpose KV cache
            k_buffer = self._transpose_kv_cache_between_head(
                k_buffer, num_blocks, block_size, block_len, num_kv_head,
                tp_num_need_pulls)
            v_buffer = self._transpose_kv_cache_between_head(
                v_buffer, num_blocks, block_size, block_len, num_kv_head,
                tp_num_need_pulls)

            # Reshape and cache the processed buffers
            torch_npu._npu_reshape_and_cache(
                key=k_buffer,
                value=v_buffer,
                key_cache=k_cache_layer,
                value_cache=v_cache_layer,
                slot_indices=slot_mapping,
            )

        # Clean up buffers
        del k_buffer, v_buffer

    def _transpose_kv_cache_between_head(
            self, buffer: torch.Tensor, num_blocks: int, block_size: int,
            block_len: int, num_kv_head: int,
            tp_num_need_pulls: int) -> torch.Tensor:
        """Swap the pull and in-block axes of ``buffer``.

        The buffer is viewed as ``(num_blocks, tp_num_need_pulls, block_size,
        -1)``, axes 1 and 2 are transposed, and the result is returned as a
        contiguous ``(block_len, num_kv_head, -1)`` tensor.
        """
        buffer = buffer.view(num_blocks, tp_num_need_pulls, block_size, -1)
        buffer.transpose_(1, 2)
        return buffer.contiguous().view(block_len, num_kv_head, -1)

    def _get_remote_metadata(self, remote_host: str,
                             remote_handshake_port: int) -> None:
        """Get the metadata from the remote host.

        The reply is cached in ``kv_caches_base_addr`` and ``remote_te_port``
        under the engine id reported by the remote side.
        """
        sock: Optional[zmq.Socket] = None  # type: ignore
        try:
            sock = self._get_remote_socket(remote_host, remote_handshake_port)
            ensure_zmq_send(sock, self.encoder.encode((GET_META_MSG, "")))
            metadata_bytes = ensure_zmq_recv(sock, self.remote_poller)
            agent_meta = self.decoder.decode(metadata_bytes)
            engine_id = agent_meta.engine_id
            assert engine_id != self.local_engine_id, (
                f"Conflict engine id {engine_id} with local engine id "
                f"{self.local_engine_id}.")
            self.kv_caches_base_addr[engine_id][remote_handshake_port] = \
                agent_meta.kv_caches_base_addr
            self.remote_te_port[engine_id][remote_handshake_port] = \
                agent_meta.te_rpc_port
        finally:
            # NOTE(review): the socket goes back to the pool even when the
            # exchange failed -- confirm a REQ socket is reusable in that
            # state.
            if sock is not None:
                self._return_remote_socket(sock, remote_host,
                                           remote_handshake_port)
                logger.debug("Returned socket to pool for %s:%d", remote_host,
                             remote_handshake_port)

    def _send_done_recv_signal(self, request_id: str, remote_host: str,
                               remote_handshake_port: int):
        """Tell the remote host that ``request_id`` has been received.

        Raises:
            RuntimeError: If the reply is anything other than ``b"ACK"``.
        """
        logger.debug("Sending done recving signal for request %s to %s:%d",
                     request_id, remote_host, remote_handshake_port)
        sock: Optional[zmq.Socket] = None  # type: ignore
        try:
            sock = self._get_remote_socket(remote_host, remote_handshake_port)
            data_bytes = self.encoder.encode((DONE_RECVING_MSG, request_id))
            ensure_zmq_send(sock, data_bytes)
            resp = ensure_zmq_recv(sock,
                                   self.remote_poller,
                                   timeout=self.timeout)
            logger.debug(
                f"Received response for request {request_id}: {resp.decode('utf-8')}"
            )
            if resp != b"ACK":
                logger.error("Failed to receive ACK for request %s from %s:%d",
                             request_id, remote_host, remote_handshake_port)
                raise RuntimeError(
                    f"Failed to receive ACK, resp: {resp.decode('utf-8')}")
        finally:
            if sock is not None:
                self._return_remote_socket(sock, remote_host,
                                           remote_handshake_port)
                logger.debug("Returned socket to pool for %s:%d", remote_host,
                             remote_handshake_port)

    def _get_remote_socket(
            self, remote_host: str,
            remote_handshake_port: int) -> zmq.Socket:  # type: ignore
        """Get a socket to the remote host.

        An idle pooled socket is reused when available; otherwise a new REQ
        socket is created, given a send timeout and registered with the
        poller.
        """
        remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port)
        with self.remote_sockets_lock:
            if self.remote_sockets[remote_path]:
                return self.remote_sockets[remote_path].popleft()

            ctx = zmq.Context()  # type: ignore
            sock = make_zmq_socket(
                ctx=ctx,
                path=remote_path,
                socket_type=zmq.REQ,  # type: ignore
                bind=False)
            sock.setsockopt(
                zmq.SNDTIMEO,  # type: ignore
                int(self.timeout * 1000))
            self.remote_poller.register(sock, zmq.POLLIN)  # type: ignore
            return sock

    def _return_remote_socket(
            self,
            sock: zmq.Socket,  # type: ignore
            remote_host: str,
            remote_handshake_port: int) -> None:
        """Return the remote socket to the pool."""
        remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port)
        with self.remote_sockets_lock:
            self.remote_sockets[remote_path].append(sock)
|
|
|
|
|
|
class MooncakeConnectorMetadata(KVConnectorMetadata):
    """Connector metadata carrying the requests to receive and to send."""

    def __init__(self):
        # request_id -> description of the blocks to pull.
        self.requests: dict[str, ReqMeta] = {}
        # request_id -> float value supplied by the producer of this metadata.
        self.requests_to_send: dict[str, float] = {}

    def add_new_req(
        self,
        request_id: str,
        local_block_ids: list[int],
        kv_transfer_params: dict[str, Any],
    ):
        """Register ``request_id`` using the given transfer parameters.

        ``remote_block_ids``, ``remote_engine_id``, ``remote_host`` and
        ``remote_port`` are required keys (``KeyError`` otherwise); the
        context-parallel sizes default to 1 and the multi-node mapping to an
        empty dict.
        """
        params = kv_transfer_params
        req_meta = ReqMeta(
            local_block_ids=local_block_ids,
            remote_block_ids=params["remote_block_ids"],
            remote_engine_id=params["remote_engine_id"],
            remote_host=params["remote_host"],
            remote_port=params["remote_port"],
            remote_pcp_size=params.get("remote_pcp_size", 1),
            remote_dcp_size=params.get("remote_dcp_size", 1),
            remote_multi_nodes_meta_mapping=params.get(
                "remote_multi_nodes_meta_mapping", {}),
        )
        self.requests[request_id] = req_meta
|
|
|
|
|
|
class MooncakeConnector(KVConnectorBase_V1):
    """KV connector facade delegating to a scheduler- or worker-side helper.

    Exactly one of ``connector_scheduler`` / ``connector_worker`` is non-None,
    depending on the role passed at construction.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 role: KVConnectorRole,
                 kv_cache_config: Optional[KVCacheConfig] = None):
        """Create the helper matching ``role``."""
        assert vllm_config.kv_transfer_config is not None
        self.engine_id = vllm_config.kv_transfer_config.engine_id
        self._connector_metadata = MooncakeConnectorMetadata()

        if role == KVConnectorRole.SCHEDULER:
            scheduler = MooncakeConnectorScheduler(vllm_config,
                                                   str(self.engine_id))
            self.connector_scheduler: Optional[MooncakeConnectorScheduler] = \
                scheduler
            self.connector_worker: Optional[MooncakeConnectorWorker] = None
        elif role == KVConnectorRole.WORKER:
            self.connector_scheduler = None
            worker = MooncakeConnectorWorker(vllm_config, str(self.engine_id))
            self.connector_worker = worker

    ############################################################
    # Scheduler Side Methods
    ############################################################

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        """Delegate to the scheduler-side helper."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.get_num_new_matched_tokens(request,
                                                    num_computed_tokens)

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """Delegate to the scheduler-side helper."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.update_state_after_alloc(request, blocks,
                                                  num_external_tokens)

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Delegate to the scheduler-side helper."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """Delegate to the scheduler-side helper."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.request_finished(request, block_ids)

    ############################################################
    # Worker Side Methods
    ############################################################
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Hand the KV cache tensors to the worker-side helper."""
        worker = self.connector_worker
        assert worker is not None
        worker.register_kv_caches(kv_caches)

    def get_finished(self,
                     finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """Get the finished recving and sending requests.

        ``finished_req_ids`` is accepted for interface compatibility and is
        not forwarded to the worker.
        """
        worker = self.connector_worker
        assert worker is not None
        return worker.get_finished()

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        """Start loading KV using the currently bound connector metadata."""
        worker = self.connector_worker
        assert worker is not None
        assert isinstance(self._connector_metadata, MooncakeConnectorMetadata)
        worker.start_load_kv(self._connector_metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """MooncakeConnector does not do layerwise saving."""
        return None

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        """MooncakeConnector does not save explicitly."""
        return None

    def wait_for_save(self):
        """MooncakeConnector does not save explicitly."""
        return None

    def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
        """
        Get the KVConnector handshake metadata for this connector.
        This metadata is used for out-of-band connector handshake
        between P/D workers.

        Returns:
            KVConnectorHandshakeMetadata: the handshake metadata.
            None if no handshake metadata is available.
        """
        worker = self.connector_worker
        assert worker is not None
        return worker.xfer_handshake_metadata

    def set_xfer_handshake_metadata(
            self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None:
        """
        Set the KV connector handshake metadata for this connector.

        Args:
            metadata (dict): the handshake metadata to set.
        """
        scheduler = self.connector_scheduler
        assert scheduler is not None
        scheduler.set_xfer_handshake_metadata(metadata)
|
|
|
|
|
|
class MooncakeConnectorScheduler:
|
|
"""Implementation of Scheduler side methods"""
|
|
|
|
def __init__(self, vllm_config: VllmConfig, engine_id: str):
    """Record configuration and derive the side-channel endpoint.

    The handshake base port is ``kv_port`` shifted by the number of ranks
    (TP x PP x PCP) belonging to lower data-parallel ranks.
    """
    self.vllm_config = vllm_config
    init_ascend_config(vllm_config)
    self.ascend_config = get_ascend_config()
    self.block_size = vllm_config.cache_config.block_size
    self.engine_id = engine_id
    self.local_ip = get_ip()
    logger.info("Initializing Mooncake Scheduler %s", engine_id)

    parallel = vllm_config.parallel_config
    self.side_channel_host = get_ip()
    self.pcp_size = parallel.prefill_context_parallel_size
    self.dcp_size = parallel.decode_context_parallel_size
    self.max_device_id = (parallel.tensor_parallel_size *
                          parallel.data_parallel_size * self.pcp_size *
                          parallel.pipeline_parallel_size)

    # Handshake base port
    dp_port_offset = (parallel.data_parallel_rank *
                      parallel.tensor_parallel_size *
                      parallel.pipeline_parallel_size * self.pcp_size)
    self.side_channel_port = (vllm_config.kv_transfer_config.kv_port +
                              dp_port_offset)

    # Requests that need to start recv.
    # New requests are added by update_state_after_alloc in
    # the scheduler. Used to make metadata passed to Worker.
    self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
    self._reqs_need_send: dict[str, float] = {}

    # master-slave meta information for cross-nodes
    self.multi_nodes_meta_mapping: dict[str, dict[str, Any]] = {}
|
|
|
|
def get_num_new_matched_tokens(
|
|
self, request: "Request",
|
|
num_computed_tokens: int) -> tuple[int, bool]:
|
|
"""
|
|
For remote prefill, pull all prompt blocks from remote
|
|
asynchronously relative to engine execution.
|
|
|
|
Args:
|
|
request (Request): the request object.
|
|
num_computed_tokens (int): the number of locally
|
|
computed tokens for this request
|
|
Returns:
|
|
* the number of tokens that can be loaded from the
|
|
external KV cache beyond what is already computed.
|
|
* true if the external KV cache tokens will be loaded
|
|
asynchronously (between scheduler steps).
|
|
"""
|
|
|
|
params = request.kv_transfer_params
|
|
logger.debug(
|
|
"MooncakeConnector get_num_new_matched_tokens: "
|
|
"num_computed_tokens=%s, kv_transfer_params=%s",
|
|
num_computed_tokens, params)
|
|
|
|
if params is not None and params.get("do_remote_prefill"):
|
|
# Remote prefill: get all prompt blocks from remote.
|
|
assert num_computed_tokens % self.block_size == 0
|
|
# Note: We use the full token count as transmit data here.
|
|
count = max(len(request.prompt_token_ids) - num_computed_tokens, 0)
|
|
return count, count > 0
|
|
|
|
# No remote prefill for this request.
|
|
return 0, False
|
|
|
|
def update_state_after_alloc(self, request: "Request",
|
|
blocks: "KVCacheBlocks",
|
|
num_external_tokens: int):
|
|
|
|
params = request.kv_transfer_params
|
|
logger.debug(
|
|
"MooncakeConnector update_state_after_alloc: "
|
|
"num_external_tokens=%s, kv_transfer_params=%s",
|
|
num_external_tokens, params)
|
|
|
|
if params is not None and params.get("do_remote_prefill"):
|
|
if params.get("remote_block_ids"):
|
|
if all(p in params for p in ("remote_engine_id", "remote_host",
|
|
"remote_port")):
|
|
local_block_ids = (blocks.get_unhashed_block_ids()
|
|
if num_external_tokens > 0 else [])
|
|
# Get unhashed blocks to pull from remote.
|
|
self._reqs_need_recv[request.request_id] = (
|
|
request, local_block_ids)
|
|
else:
|
|
logger.warning(
|
|
"Got invalid KVTransferParams: %s. This "
|
|
"request will not utilize KVTransfer", params)
|
|
else:
|
|
assert num_external_tokens == 0
|
|
# Only trigger 1 KV transfer per request.
|
|
params["do_remote_prefill"] = False
|
|
|
|
def build_connector_meta(
|
|
self,
|
|
scheduler_output: SchedulerOutput,
|
|
) -> KVConnectorMetadata:
|
|
meta = MooncakeConnectorMetadata()
|
|
|
|
# Loop through scheduled reqs and convert to ReqMeta.
|
|
for req_id, (req, block_ids) in self._reqs_need_recv.items():
|
|
assert req.kv_transfer_params is not None
|
|
# For the case where there are no remote blocks to pull
|
|
# (block_ids is empty), we don't need to schedule
|
|
# an async read on the worker side.
|
|
meta.add_new_req(
|
|
request_id=req_id,
|
|
local_block_ids=block_ids,
|
|
kv_transfer_params=req.kv_transfer_params,
|
|
)
|
|
|
|
# Clear the list once workers start the transfers
|
|
self._reqs_need_recv.clear()
|
|
meta.requests_to_send = self._reqs_need_send
|
|
self._reqs_need_send = {}
|
|
|
|
return meta
|
|
|
|
def request_finished(
|
|
self,
|
|
request: "Request",
|
|
block_ids: list[int],
|
|
) -> tuple[bool, Optional[dict[str, Any]]]:
|
|
"""
|
|
Once a request is finished, determine whether request blocks
|
|
should be freed now or will be sent asynchronously and freed later.
|
|
"""
|
|
|
|
params = request.kv_transfer_params
|
|
logger.debug(
|
|
"MooncakeConnector request_finished, request_status=%s, "
|
|
"kv_transfer_params=%s", request.status, params)
|
|
|
|
if (params is None or not params.get("do_remote_decode")
|
|
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
|
|
return False, None
|
|
|
|
computed_block_ids = block_ids
|
|
delay_free_blocks = len(computed_block_ids) > 0
|
|
if delay_free_blocks:
|
|
logger.info("Delaying free of %d blocks for request %s",
|
|
len(computed_block_ids), request.request_id)
|
|
self._reqs_need_send[request.request_id] = time.time()
|
|
|
|
return delay_free_blocks, dict(
|
|
do_remote_prefill=True,
|
|
do_remote_decode=False,
|
|
remote_block_ids=computed_block_ids,
|
|
remote_engine_id=self.engine_id,
|
|
remote_host=self.side_channel_host,
|
|
remote_port=self.side_channel_port,
|
|
remote_pcp_size=self.pcp_size,
|
|
remote_dcp_size=self.dcp_size,
|
|
last_token_id=request.output_token_ids[-1],
|
|
remote_multi_nodes_meta_mapping=self.multi_nodes_meta_mapping,
|
|
)
|
|
|
|
def set_xfer_handshake_metadata(
|
|
self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None:
|
|
"""
|
|
Set the KV connector handshake metadata for this connector.
|
|
|
|
Args:
|
|
metadata (dict): the handshake metadata to set.
|
|
"""
|
|
for local_rank, rank_metadata in metadata.items():
|
|
self.multi_nodes_meta_mapping[str(local_rank)] = {
|
|
"host": rank_metadata.local_ip,
|
|
"engine_id": rank_metadata.engine_id,
|
|
}
|
|
|
|
|
|
class MooncakeConnectorWorker:
    """Implementation of Worker side methods"""

    def __init__(self, vllm_config: VllmConfig, engine_id: str):
        """Read the parallel layout, derive the handshake ports and fetch the
        transfer engine.

        Args:
            vllm_config: engine configuration. Its ``kv_transfer_config``
                extra config must describe the "prefill" and "decode"
                deployments (see ``_get_prefill_decode_size``).
            engine_id: identifier stored on the worker and placed in its
                handshake metadata.

        Raises:
            ValueError: if the prefill TP size is smaller than the decode TP
                size.
        """
        self._get_prefill_decode_size(vllm_config)
        os.environ["ASCEND_TRANSFER_TIMEOUT"] = str(
            get_transfer_timeout_value())
        if self._prefill_tp_size < self._decode_tp_size:
            raise ValueError(
                f"prefill_tp_size: {self._prefill_tp_size} must be greater than"
                f" or equal to the decode_tp_size: {self._decode_tp_size}")

        # Metadata.
        self.vllm_config = vllm_config
        self.ascend_config = get_ascend_config()
        self.engine_id = engine_id
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = vllm_config.parallel_config.tensor_parallel_size
        self.tp_group = get_tp_group()
        self.pp_rank = get_pp_group().rank_in_group
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank_local
        self.dp_size = vllm_config.parallel_config.data_parallel_size_local
        self.pp_size = vllm_config.parallel_config.pipeline_parallel_size
        self.kv_caches: dict[str, torch.Tensor] = {}
        self.side_channel_host = get_ip()
        self.pcp_size = get_prefill_context_model_parallel_world_size(
        ) if prefill_context_parallel_enable() else 1
        # Assert that pp_size and pcp_size cannot both be greater than 1
        assert not (self.pp_size > 1 and self.pcp_size
                    > 1), "pp and pcp cannot open in same time"
        self.pcp_rank = get_prefill_context_model_parallel_rank(
        ) if self.pcp_size > 1 else 0
        self.dcp_size = get_decode_context_model_parallel_world_size()
        self.dcp_rank = get_decode_context_model_parallel_rank(
        ) if self.dcp_size > 1 else 0

        self.max_device_id = self.tp_size * self.dp_size * self.pcp_size * self.pp_size
        self.kv_role = vllm_config.kv_transfer_config.kv_role
        self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads

        # Handshake base port
        self.side_channel_port = (
            vllm_config.kv_transfer_config.kv_port +
            vllm_config.parallel_config.data_parallel_rank *
            vllm_config.parallel_config.tensor_parallel_size *
            vllm_config.parallel_config.pipeline_parallel_size * self.pcp_size)
        # This worker's own port is the base port plus its device index.
        device_index = (self.pp_rank +
                        self.pcp_rank) * self.tp_size + self.tp_rank
        self.handshake_port = self.side_channel_port + device_index
        self.sockets: dict = {}
        self.engine = global_te.get_transfer_engine(self.side_channel_host,
                                                    device_name=None)
        self.te_rpc_port = self.engine.get_rpc_port()

        # Background thread for sending or receiving KV caches.
        self.kv_send_thread: Optional[KVCacheSendingThread] = None
        self.kv_recv_thread: Optional[KVCacheRecvingThread] = None

        # Handshake metadata of this worker
        self.xfer_handshake_metadata: MooncakeAgentMetadata | None = None

        # kv_transfer variables
        self.block_size = vllm_config.cache_config.block_size
        if self.vllm_config.model_config.is_deepseek_mla:
            self.tp_num_need_pulls = 1
        else:
            # Ratio of KV heads held per decode rank to KV heads held per
            # prefill rank (each clamped to at least one head).
            num_d_block_heads = max(1,
                                    self.num_key_value_heads // self.tp_size)
            num_p_block_heads = max(
                1, self.num_key_value_heads // self._prefill_tp_size)
            self.tp_num_need_pulls = num_d_block_heads // num_p_block_heads

    def _get_prefill_decode_size(self, vllm_config: VllmConfig):
        """Load the prefill/decode TP, DP and PP sizes from the extra config.

        Reads the "prefill" and "decode" entries of the KV-transfer extra
        config. ``tp_size`` and ``dp_size`` are mandatory in both; ``pp_size``
        defaults to 1, and the decode ``pp_size`` must be 1.
        """
        kv_transfer_config = vllm_config.kv_transfer_config

        # get prefill tp and dp size from extra config
        prefill_parallel_config: dict[
            str, Any] = kv_transfer_config.get_from_extra_config(
                "prefill", {})

        assert "tp_size" in prefill_parallel_config.keys(), \
            "prefill tp_size is missing from the kv transfer extra config"
        self._prefill_tp_size = prefill_parallel_config["tp_size"]

        assert "dp_size" in prefill_parallel_config.keys(), \
            "prefill dp_size is missing from the kv transfer extra config"
        self._prefill_dp_size = prefill_parallel_config["dp_size"]
        # get prefill pp size from extra config
        self._prefill_pp_size = prefill_parallel_config.get("pp_size", 1)

        # get decode tp and dp size from extra config
        decode_parallel_config: dict[
            str, Any] = kv_transfer_config.get_from_extra_config(
                "decode", {})
        assert "tp_size" in decode_parallel_config.keys(), \
            "decode tp_size is missing from the kv transfer extra config"
        self._decode_tp_size = decode_parallel_config["tp_size"]
        assert "dp_size" in decode_parallel_config.keys(), \
            "decode dp_size is missing from the kv transfer extra config"
        self._decode_dp_size = decode_parallel_config["dp_size"]
        # get decode pp size from extra config
        self._decode_pp_size = decode_parallel_config.get("pp_size", 1)
        assert self._decode_pp_size == 1, "decode pp size must be 1"

    @staticmethod
    def _block_byte_lengths(caches, block_rank: int) -> list[int]:
        """Return the size in bytes of one block for each tensor in *caches*.

        Each size is the tensor's own element size multiplied by the product
        of its trailing ``block_rank`` dimensions.
        """
        return [
            cache.element_size() * math.prod(cache.shape[-block_rank:])
            for cache in caches
        ]

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Register the KV Cache data.

        Works out the per-block byte length of each cache tensor, registers
        every cache region with the transfer engine, publishes this worker's
        handshake metadata, and starts the sending thread (``kv_producer``)
        or the receiving thread (any other role), waiting until it is ready.

        Raises:
            RuntimeError: if the background thread exits before signalling
                readiness, or is not ready within five minutes.
        """

        _, first_kv_cache_tuple = next(iter(kv_caches.items()))
        first_kv_cache = first_kv_cache_tuple[0]

        # TODO(tms): Find a more robust way to detect and handle MLA
        self.use_mla = first_kv_cache_tuple[0].size(
            -1) != first_kv_cache_tuple[1].size(-1) and len(
                first_kv_cache_tuple) == 2
        self.use_sparse = len(first_kv_cache_tuple) == 3
        self.num_blocks = first_kv_cache.shape[0]
        if self.use_mla:
            # MLA case.[num_block, block_size, 1, hidden_dim]
            block_rank = 3  # trailing dims that make up one block
            block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
            block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
            # Use each tensor's own element size; indexing first_kv_cache
            # would select a block of the first tensor instead.
            self.block_len = self._block_byte_lengths(first_kv_cache_tuple,
                                                      block_rank)
            logger.info(
                "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
                self.num_blocks, block_shape_norm, block_shape_pe)
        elif self.use_sparse:
            block_rank = 3  # trailing dims that make up one block
            block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
            block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
            block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
            self.block_len = self._block_byte_lengths(first_kv_cache_tuple,
                                                      block_rank)
            logger.info(
                "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
                self.num_blocks, block_shape_norm, block_shape_pe,
                block_shape_k)
        else:
            # eager:[num_block, block_size, num_head, hidden_dim]
            # torchair:[num_block, block_size, num_head*hidden_dim]
            kv_elem_size = first_kv_cache.element_size()
            block_rank = len(
                first_kv_cache.shape
            ) - 1  # [block_size, kv_heads, head_dim] or [block_size, kv_heads*head_dim]
            block_shape = first_kv_cache.shape[-block_rank:]
            self.block_len = [kv_elem_size * math.prod(block_shape)]
            logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
                        block_shape)
        logger.info(
            "Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
            self.use_mla, self.use_sparse, first_kv_cache.shape)

        self.kv_caches = kv_caches
        kv_caches_base_addr = []
        ptrs = []
        lengths = []
        # block_len holds one entry per tensor of a layer (2 for MLA, 3 for
        # sparse, 1 otherwise), so cycling through it pairs every tensor with
        # its own block length.
        num_block_lens = len(self.block_len)
        for cache_or_caches in kv_caches.values():
            for i, cache in enumerate(cache_or_caches):
                base_addr = cache.data_ptr()
                region_len = self.num_blocks * self.block_len[i %
                                                              num_block_lens]
                kv_caches_base_addr.append(base_addr)
                ptrs.append(base_addr)
                lengths.append(region_len)
        global_te.register_buffer(ptrs, lengths)
        # After KV Caches registered, start the sending or receiving thread.
        metadata = MooncakeAgentMetadata(
            engine_id=self.engine_id,
            te_rpc_port=self.te_rpc_port,
            kv_caches_base_addr=kv_caches_base_addr,
            num_blocks=self.num_blocks,
            local_ip=get_ip(),
        )
        self.xfer_handshake_metadata = metadata

        ready_event = threading.Event()
        if self.kv_role == 'kv_producer':
            self.kv_send_thread = KVCacheSendingThread(
                self.vllm_config, self.tp_rank, self._prefill_tp_size,
                self.engine_id, self.side_channel_host, self.side_channel_port,
                metadata, ready_event, self.kv_caches, self.pcp_rank)
            self.kv_send_thread.start()
        else:
            self.kv_recv_thread = KVCacheRecvingThread(
                self.tp_rank, self.tp_size, self._prefill_pp_size, self.engine,
                self.engine_id, self.handshake_port, kv_caches_base_addr,
                self.block_len, ready_event, self.vllm_config, self.kv_caches)
            self.kv_recv_thread.start()

        start_wait_time = time.time()
        thread = self.kv_send_thread if self.kv_role == 'kv_producer' else self.kv_recv_thread
        assert thread is not None
        while not ready_event.is_set():
            if not thread.is_alive():
                raise RuntimeError(
                    "KV Cache sending/receiving thread failed to start.")
            if time.time() - start_wait_time > 5 * 60:
                raise RuntimeError(
                    "Timeout waiting for KV Cache thread to be ready.")
            # Block on the event rather than sleeping so that we continue as
            # soon as the thread reports ready; still re-check every 3s.
            ready_event.wait(timeout=3)

    def get_finished(self) -> tuple[set[str], set[str]]:
        """Return and clear the ids of requests whose transfer has finished.

        Returns:
            ``(done_sending, done_recving)``. Only the set matching this
            worker's role (``kv_producer`` / ``kv_consumer``) is collected
            from its thread; the other is empty.
        """
        done_sending = (
            self.kv_send_thread.
            get_and_clear_finished_requests(  # type: ignore[union-attr]
            ) if self.kv_role == 'kv_producer' else set())
        done_recving = (
            self.kv_recv_thread.
            get_and_clear_finished_requests(  # type: ignore[union-attr]
            ) if self.kv_role == 'kv_consumer' else set())
        if self.tp_rank == 0:
            logger.debug(
                "Number of completed KV cache send requests: %d, receive "
                "requests: %d", len(done_sending), len(done_recving))
        return done_sending, done_recving

    def _get_kv_split_metadata(
        self,
        req_id: str,
        meta: ReqMeta,
    ) -> tuple[list[list[int]], list[list[int]], list[list[int]]]:
        """
        In the cp/dcp scenario the kv_cache may be split, so blocks have to be
        pulled from several remote P nodes. This computes, for each remote P
        node to pull from, its handshake port and the local/remote block ids.

        Returns:
            ``(remote_handshake_port_list, local_block_ids_list,
            remote_block_ids_list)``; entries at the same index belong
            together.
        """
        if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1:
            # No context parallelism on either side: pull everything from the
            # ranks chosen for this request.
            choosen_rank_list = self._get_remote_rank(req_id)
            remote_handshake_port_list = [[
                x + meta.remote_port for x in choosen_rank_list
            ]]
            local_block_ids_list, remote_block_ids_list = [
                meta.local_block_ids
            ], [meta.remote_block_ids]
            return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list

        if self.pcp_size == meta.remote_pcp_size and self.dcp_size == meta.remote_dcp_size:
            # remote & local cp/dcp are equal, do kv transfer point-to-point
            remote_kv_num = 1
            remote_ports = [
                meta.remote_port + self.pcp_rank * self.tp_size + tp_offset
                for tp_offset in range(self.tp_rank,
                                       int(self._prefill_tp_size),
                                       self.tp_size)
            ]
            remote_block_nums = [len(meta.remote_block_ids)]
        else:
            assert self.pcp_size == 1
            if self.use_mla:
                assert (self.dcp_size == 1 and (self.tp_size == 1 or self.tp_size == self._prefill_tp_size)) or \
                       (self.dcp_size == meta.remote_dcp_size and self.tp_size == self._prefill_tp_size)
            else:
                assert self.tp_size == self._prefill_tp_size and (
                    self.dcp_size == 1
                    or self.dcp_size == meta.remote_dcp_size)
            # remote & local cp/dcp are not equal, each D node needs to pull from pcp(*dcp) P nodes
            # 1. for mla, support D pcp_size = 1, D dcp_size = (1 or P dcp_size)
            # 2. for gqa, support D tp_size = P tp_size, D dcp_size = P dcp_size
            remote_dcp_size = meta.remote_dcp_size // self.dcp_size
            remote_kv_num = meta.remote_pcp_size * remote_dcp_size
            cp_dcp_offsets = []
            for cp_idx in range(meta.remote_pcp_size):
                cp_offset = cp_idx * self._prefill_tp_size
                cp_dcp_offsets += list(
                    range(cp_offset, cp_offset + remote_dcp_size))
            tp_offset = self.tp_rank // remote_dcp_size * remote_dcp_size
            remote_ports = [
                meta.remote_port + cp_dcp_offset + tp_offset
                for cp_dcp_offset in cp_dcp_offsets
            ]
            # recompute cp/dcp block assign here, maybe we can also pass it from P node meta
            local_block_num = len(meta.local_block_ids)
            remote_block_nums = [
                local_block_num // (meta.remote_pcp_size * remote_dcp_size)
            ] * meta.remote_pcp_size * remote_dcp_size
            num_remain_blocks = local_block_num % (meta.remote_pcp_size *
                                                   remote_dcp_size)
            for i in range(num_remain_blocks):
                remote_block_nums[i] += 1
            # make sure the last block (which may be unfull) of P nodes is put to the last block of D node
            remote_ports = remote_ports[
                num_remain_blocks:] + remote_ports[:num_remain_blocks]
            remote_block_nums = remote_block_nums[
                num_remain_blocks:] + remote_block_nums[:num_remain_blocks]

        remote_handshake_port_list = [[remote_ports[remote_kv_id]]
                                      for remote_kv_id in range(remote_kv_num)]

        # the local_block_ids_list and remote_block_ids_list are related with remote_handshake_port_list
        # such as: local_block_ids_list[[1],[2],[5],[6]], remote_block_ids_list[[1],[1],[1],[1]],
        # remote_handshake_port_list[[30000],[30001],[30004],[30005]]
        # D rank will get remote block 1 in port 30004 and save it in local block 5
        local_block_ids_list = []
        remote_block_ids_list = []
        local_block_offset = 0
        for remote_kv_id in range(len(remote_handshake_port_list)):
            num_blocks_to_pull = remote_block_nums[remote_kv_id]
            remote_block_ids_list.append(
                meta.remote_block_ids[:num_blocks_to_pull])
            local_block_ids_list.append(
                meta.local_block_ids[local_block_offset:local_block_offset +
                                     num_blocks_to_pull])
            local_block_offset += num_blocks_to_pull
        assert local_block_offset == len(meta.local_block_ids), \
            f"local_block_offset ({local_block_offset}) should equal with local_block_ids len ({len(meta.local_block_ids)})"

        return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list

    def start_load_kv(self, metadata: MooncakeConnectorMetadata):
        """Start loading KV blocks from remote engine.

        Every request in ``metadata.requests`` is turned into one or more
        pull tasks on the receiving thread. When a sending thread exists,
        each id in ``metadata.requests_to_send`` is registered with it as
        either a delayed request or a request this rank does not transfer.
        """
        for req_id, meta in metadata.requests.items():
            logger.debug(
                "start_load_kv for request %s from remote engine %s. "
                "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id,
                meta.remote_engine_id, len(meta.local_block_ids),
                len(meta.remote_block_ids))

            if prefill_context_parallel_enable():
                remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
                    req_id, meta)

                # NOTE(review): a group with no local and no remote blocks is
                # skipped; if the last group were skipped no task would carry
                # all_task_done=True - confirm this cannot happen.
                for pcp_dcp_rank in range(len(remote_handshake_port_list)):
                    if len(local_block_ids_list[pcp_dcp_rank]) + len(
                            remote_block_ids_list[pcp_dcp_rank]) == 0:
                        continue
                    for i in range(self.tp_num_need_pulls):
                        assert self.kv_recv_thread is not None
                        # The rank may live on another node; look up its host
                        # and engine id, falling back to the request's own.
                        remote_host, remote_engine_id = self._get_remote_host_info_by_port(
                            meta.remote_port,
                            remote_handshake_port_list[pcp_dcp_rank][i],
                            meta.remote_host, meta.remote_engine_id,
                            meta.remote_multi_nodes_meta_mapping)
                        self.kv_recv_thread.add_request(
                            request_id=req_id,
                            local_block_ids=local_block_ids_list[pcp_dcp_rank],
                            remote_block_ids=remote_block_ids_list[
                                pcp_dcp_rank],
                            remote_engine_id=remote_engine_id,
                            remote_host=remote_host,
                            remote_handshake_port=remote_handshake_port_list[
                                pcp_dcp_rank][i],
                            offset=i,
                            tp_num_need_pulls=self.tp_num_need_pulls,
                            all_task_done=(
                                pcp_dcp_rank
                                == len(remote_handshake_port_list) - 1
                                and i == self.tp_num_need_pulls - 1))
            else:  #TODO: support prefill context parallel and pipeline parallel open at the same time
                choosen_rank_list = self._get_remote_rank(req_id)
                remote_handshake_port_list = [[x + meta.remote_port]
                                              for x in choosen_rank_list]
                for i in range(self.tp_num_need_pulls * self._prefill_pp_size):
                    assert self.kv_recv_thread is not None
                    remote_host, remote_engine_id = self._get_remote_host_info_by_port(
                        meta.remote_port, remote_handshake_port_list[i][0],
                        meta.remote_host, meta.remote_engine_id,
                        meta.remote_multi_nodes_meta_mapping)
                    self.kv_recv_thread.add_request(
                        request_id=req_id,
                        local_block_ids=meta.local_block_ids,
                        remote_block_ids=meta.remote_block_ids,
                        remote_engine_id=remote_engine_id,
                        remote_host=remote_host,
                        remote_handshake_port=remote_handshake_port_list[i][0],
                        offset=i,
                        tp_num_need_pulls=self.tp_num_need_pulls,
                        all_task_done=(i == self.tp_num_need_pulls *
                                       self._prefill_pp_size - 1))

        if self.kv_send_thread is not None:
            for req_id, delay_start_time in metadata.requests_to_send.items():
                if self.tp_rank in self._prefill_get_remote_rank(req_id):
                    self.kv_send_thread.add_delayed_request(
                        req_id, delay_start_time)
                else:
                    self.kv_send_thread.add_not_transfer_request(req_id)

    def _get_remote_host_info_by_port(self, base_port: int,
                                      remote_handshake_port: int,
                                      remote_host: str, remote_engine_id: str,
                                      remote_multi_nodes_meta_mapping: dict):
        """Resolve the host and engine id serving ``remote_handshake_port``.

        The rank is the port's offset from ``base_port``, looked up by its
        string form in the mapping. Falls back to ``remote_host`` /
        ``remote_engine_id`` when the mapping is ``None``, has no entry for
        the rank, or the entry lacks the corresponding key.

        Returns:
            ``(host, engine_id)``
        """
        if remote_multi_nodes_meta_mapping is None:
            return remote_host, remote_engine_id
        rank = str(remote_handshake_port - base_port)
        info = remote_multi_nodes_meta_mapping.get(rank)
        if info is None:
            return remote_host, remote_engine_id
        return info.get("host", remote_host), info.get("engine_id",
                                                       remote_engine_id)

    def _prefill_get_remote_rank(self, req_id: str) -> List[int]:
        """Return every remote rank chosen for ``req_id`` as one flat list."""
        # A comprehension flattens in linear time; sum(lists, []) copies the
        # accumulated list on every step.
        return [
            rank for group in self._get_remote_ranks_for_req(req_id)
            for rank in group
        ]

    def _get_remote_rank(self, req_id: str) -> List[int]:
        """Return the remote ranks assigned to this worker's TP rank."""
        return self._get_remote_ranks_for_req(req_id)[self.tp_rank]

    def _get_remote_tp_ranks(self, tp_ori_data: np.ndarray,
                             rand_group_index: list[int],
                             num_groups: int) -> List[List[int]]:
        """Split one PP stage's prefill ranks into per-decode-rank groups.

        Args:
            tp_ori_data: the rank ids of one pipeline stage.
            rand_group_index: column indices selected for this request.
            num_groups: number of columns ``tp_ori_data`` is reshaped into
                before the selected columns are taken.

        Returns:
            A list of rank lists.
        """
        tp_sampled_nums = []
        # random split prefill tp list
        if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.use_sparse:
            tp_ori_data = tp_ori_data.reshape(-1, num_groups)
            choosen_group = tp_ori_data[:, [rand_group_index]]
            flattened = choosen_group.reshape(-1).tolist()
            tp_sampled_nums = [
                flattened[i:i + self.tp_num_need_pulls]
                for i in range(0, len(flattened), self.tp_num_need_pulls)
            ]
        # non-random split
        else:
            group_size = self._prefill_tp_size // self._decode_tp_size
            for i in range(self._decode_tp_size):
                rank_chunk = tp_ori_data[i * group_size:(i + 1) * group_size]
                tp_sampled_nums.append(rank_chunk.tolist())
        return tp_sampled_nums

    def _get_remote_ranks_for_req(self, req_id: str) -> List[List[int]]:
        """Choose, per decode TP rank, the prefill ranks to pull from.

        When prefill and decode TP sizes match, TP rank ``t`` maps to rank
        ``t`` of every prefill PP stage. Otherwise the choice is drawn from a
        ``random.Random`` seeded with a hash of ``req_id``, so the same
        request id always yields the same selection.
        """
        # Divide the ports according to the TP within the PP
        sampled_nums = []
        if self._prefill_tp_size == self._decode_tp_size:
            sampled_nums = [[
                tp + pp * self._prefill_tp_size
                for pp in range(self._prefill_pp_size)
            ] for tp in range(self._prefill_tp_size)]
            return sampled_nums
        # use deepseek mla, num_key_value_heads == 128, but consider as 1
        if self.vllm_config.model_config.is_deepseek_mla or self.use_sparse:
            num_kv_head = 1
        else:
            num_kv_head = self.num_key_value_heads
        ori_data = np.arange(self._prefill_tp_size * self._prefill_pp_size)
        seed = string_to_int64_hash(req_id)
        rand = random.Random(seed)
        # random split prefill tp list
        ori_data = ori_data.reshape(self._prefill_pp_size, -1)
        num_groups = max(
            1,
            len(ori_data[0]) // num_kv_head
        )  # The number of redundant copies for each KV head within the PP stage
        rand_group_index = rand.sample(
            range(num_groups),
            max(self._decode_tp_size // num_kv_head,
                1))  # random choose a group
        all_results = [
            self._get_remote_tp_ranks(ori_data[pp_index], rand_group_index,
                                      num_groups)
            for pp_index in range(self._prefill_pp_size)
        ]
        # Merge the per-stage results so each group lists its ranks from
        # every PP stage in stage order.
        for group_index in range(len(all_results[0])):
            group = []
            for pp_index in range(self._prefill_pp_size):
                group.extend(all_results[pp_index][group_index])
            sampled_nums.append(group)
        return sampled_nums
|
|
|
|
|
|
@contextlib.contextmanager
def zmq_ctx(socket_type: Any,
            addr: str) -> Iterator[zmq.Socket]:  # type: ignore
    """Yield a ZMQ socket of ``socket_type`` on ``addr`` in a fresh context.

    Only ROUTER, REQ and DEALER sockets are accepted; a ROUTER socket is
    bound, the others are not. The context is destroyed with ``linger=0``
    when the ``with`` block exits.

    Raises:
        ValueError: if ``socket_type`` is not one of the accepted types.
    """
    accepted_types = (zmq.ROUTER, zmq.REQ, zmq.DEALER)  # type: ignore
    if socket_type not in accepted_types:
        raise ValueError(f"Unexpected socket type: {socket_type}")

    ctx: Optional[zmq.Context] = None  # type: ignore
    try:
        ctx = zmq.Context()  # type: ignore
        should_bind = socket_type == zmq.ROUTER  # type: ignore
        sock = make_zmq_socket(ctx=ctx,
                               path=addr,
                               socket_type=socket_type,
                               bind=should_bind)
        yield sock
    finally:
        # ctx stays None if creating the context itself failed.
        if ctx is not None:
            ctx.destroy(linger=0)
|
|
|
|
|
|
def group_concurrent_contiguous(
        src: List[int], dst: List[int]
) -> Tuple[List[List[int]], List[List[int]]]:
    """Split two paired index lists into runs contiguous in both.

    A new group starts at every position where either ``src`` or ``dst``
    does not increase by exactly one from the previous element, so the
    i-th source group and the i-th destination group always have the same
    length.

    Args:
        src: source indices.
        dst: destination indices, paired element-wise with ``src``.

    Returns:
        ``(src_groups, dst_groups)`` as plain lists of ``int`` lists.
        Both are empty when ``src`` is empty.
    """
    src_indices: npt.NDArray[np.int64] = np.array(src, dtype=np.int64)
    dst_indices: npt.NDArray[np.int64] = np.array(dst, dtype=np.int64)

    if src_indices.size == 0:
        return [], []

    # Positions where the step on either side is not +1; the +1 turns an
    # index into the diff array into a split point in the original array.
    brk = np.where((np.diff(src_indices) != 1)
                   | (np.diff(dst_indices) != 1))[0] + 1

    src_groups = [g.tolist() for g in np.split(src_indices, brk)]
    dst_groups = [g.tolist() for g in np.split(dst_indices, brk)]

    return src_groups, dst_groups
|
|
|
|
|
|
def string_to_int64_hash(input_str):
    """
    Derive a 64-bit integer from the SHA-256 digest of a string.

    The string is UTF-8 encoded and hashed; the first eight digest bytes are
    read as a little-endian unsigned integer, so the result lies in
    ``[0, 2**64)``.
    """
    digest = hashlib.sha256(input_str.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], byteorder="little", signed=False)
|
|
|
|
|
|
def ensure_zmq_send(
        socket: zmq.Socket,  # type: ignore
        data: bytes,
        max_retries: int = 3):
    """Send ``data`` on ``socket``, retrying when ZMQ reports an error.

    Each failed send uses up one retry and is followed by a 0.1 s pause;
    once no retries are left the failure is logged and re-raised.

    Raises:
        RuntimeError: if sending still fails after the retries are used up.
    """
    retries_left = max_retries
    while True:
        try:
            socket.send(data)
        except zmq.ZMQError as e:  # type: ignore
            retries_left -= 1
            if retries_left <= 0:
                logger.error(f"Send failed after all retries: {e}")
                raise RuntimeError(f"Failed to send data after {max_retries} "
                                   f"retries: {e}")
            logger.warning(
                f"Send failed: {e}, retrying... ({retries_left} "
                "attempts left)")
            time.sleep(0.1)
        else:
            return
|
|
|
|
|
|
def ensure_zmq_recv(
        socket: zmq.Socket,  # type: ignore
        poller: zmq.Poller,  # type: ignore
        timeout: float = 1.0,
        max_retries: int = 3) -> bytes:
    """Receive one message from ``socket``, retrying on timeout or error.

    ``poller`` is polled for up to ``timeout`` seconds per attempt; an empty
    poll result counts as a failure just like a ZMQ error. Each failure uses
    up one retry and is followed by a 0.1 s pause.

    Returns:
        The bytes returned by ``socket.recv()``.

    Raises:
        RuntimeError: if nothing was received after the retries are used up.
    """
    timeout_ms = int(timeout * 1000)  # milliseconds
    retries_left = max_retries
    while True:
        try:
            ready = dict(poller.poll(timeout_ms))
            if not ready:
                raise zmq.ZMQError("Receive timeout")  # type: ignore
            return socket.recv()
        except zmq.ZMQError as e:  # type: ignore
            retries_left -= 1
            if retries_left <= 0:
                logger.error(f"Receive failed after all retries: {e}")
                raise RuntimeError(
                    f"Failed to receive data after {max_retries} "
                    f"retries: {e}")
            logger.warning(f"Receive failed: {e}, retrying... "
                           f"({retries_left} attempts left)")
            time.sleep(0.1)
|