# sglang/python/sglang/srt/disaggregation/mooncake/conn.py
from __future__ import annotations
import asyncio
import concurrent.futures
import ctypes
import dataclasses
import logging
import os
import queue
import socket
import struct
import threading
import time
from collections import defaultdict
from functools import cache
from typing import Dict, List, Optional, Set, Tuple, Union
import numpy as np
import numpy.typing as npt
import requests
import zmq
from aiohttp import web
from sglang.srt.disaggregation.base.conn import (
BaseKVBootstrapServer,
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)
from sglang.srt.disaggregation.common.utils import (
FastQueue,
group_concurrent_contiguous,
)
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.dp_attention import (
get_attention_dp_rank,
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
format_tcp_address,
get_bool_env_var,
get_free_port,
get_int_env_var,
get_ip,
get_local_ip_auto,
is_valid_ipv6_address,
maybe_wrap_ipv6_address,
)
logger = logging.getLogger(__name__)
class KVTransferError(Exception):
def __init__(self, bootstrap_room: int, failure_reason: str):
super().__init__(failure_reason)
self.bootstrap_room = bootstrap_room
self.failure_reason = failure_reason
def __str__(self):
return f"KVTransferError(bootstrap_room={self.bootstrap_room}): {self.failure_reason}"
# prefill: a chunk of KV cache queued for transfer on the prefill side
@dataclasses.dataclass
class TransferKVChunk:
room: int
prefill_kv_indices: npt.NDArray[np.int32]
index_slice: slice
is_last: bool
prefill_aux_index: Optional[int]
# decode: per-request transfer info received from a decode rank
@dataclasses.dataclass
class TransferInfo:
room: int
endpoint: str
dst_port: int
mooncake_session_id: str
dst_kv_indices: npt.NDArray[np.int32]
dst_aux_index: int
required_dst_info_num: int
is_dummy: bool
@classmethod
def from_zmq(cls, msg: List[bytes]):
if msg[4] == b"" and msg[5] == b"":
is_dummy = True
dst_kv_indices = np.array([], dtype=np.int32)
dst_aux_index = None
else:
dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32)
dst_aux_index = int(msg[5].decode("ascii"))
is_dummy = False
return cls(
room=int(msg[0].decode("ascii")),
endpoint=msg[1].decode("ascii"),
dst_port=int(msg[2].decode("ascii")),
mooncake_session_id=msg[3].decode("ascii"),
dst_kv_indices=dst_kv_indices,
dst_aux_index=dst_aux_index,
required_dst_info_num=int(msg[6].decode("ascii")),
is_dummy=is_dummy,
)
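# A TransferInfo request arrives as a 7-frame ZMQ multipart message sent by
# MooncakeKVReceiver.init() on the decode side (frame values below are
# illustrative, not from a real trace):
#   msg[0]: bootstrap room id, ASCII        e.g. b"42"
#   msg[1]: decode endpoint IP              e.g. b"10.0.0.2"
#   msg[2]: decode rank port, ASCII         e.g. b"18000"
#   msg[3]: mooncake session id
#   msg[4]: dst KV indices as raw int32 bytes (b"" for a dummy request)
#   msg[5]: dst aux index, ASCII              (b"" for a dummy request)
#   msg[6]: required_dst_info_num, ASCII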
# decode: KV pointer registration received from a decode rank
@dataclasses.dataclass
class KVArgsRegisterInfo:
room: str
endpoint: str
dst_port: int
mooncake_session_id: str
dst_kv_ptrs: list[int]
dst_aux_ptrs: list[int]
dst_tp_rank: int
dst_attn_tp_size: int
dst_kv_item_len: int
@classmethod
def from_zmq(cls, msg: List[bytes]):
return cls(
room=str(msg[0].decode("ascii")),
endpoint=msg[1].decode("ascii"),
dst_port=int(msg[2].decode("ascii")),
mooncake_session_id=msg[3].decode("ascii"),
dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
dst_tp_rank=int(msg[6].decode("ascii")),
dst_attn_tp_size=int(msg[7].decode("ascii")),
dst_kv_item_len=int(msg[8].decode("ascii")),
)
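# A KVArgsRegisterInfo message is the 9-frame variant of the same bootstrap
# channel, distinguished by the literal room b"None" in frame 0 (see
# start_prefill_thread). Frames 4 and 5 pack the destination KV/aux pointers
# as native-endian uint64 ("Q") arrays, hence the struct.unpack(f"{len//8}Q")
# decoding above.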
class AuxDataCodec:
"""Handles serialization and deserialization of auxiliary data buffers"""
@staticmethod
def serialize_data_from_buffer(src_addr, data_length):
"""Serialize data from memory buffer to bytes"""
buffer = (ctypes.c_byte * data_length).from_address(src_addr)
return bytes(buffer)
@staticmethod
def deserialize_data_to_buffer(kv_args, buffer_index, aux_index, data):
"""Deserialize bytes into target memory buffer"""
dst_aux_ptr = kv_args.aux_data_ptrs[buffer_index]
item_len = kv_args.aux_item_lens[buffer_index]
dst_addr = dst_aux_ptr + item_len * aux_index
buffer = (ctypes.c_byte * len(data)).from_address(dst_addr)
buffer[:] = data
return
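# Worked example of the aux addressing above (hypothetical numbers): with
# kv_args.aux_item_lens[0] == 512 and aux_index == 3, buffer 0 of a request
# is written at aux_data_ptrs[0] + 512 * 3, i.e. the fourth fixed-size slot
# of that aux buffer.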
class MooncakeKVManager(BaseKVManager):
AUX_DATA_HEADER = b"AUX_DATA"
def __init__(
self,
args: KVArgs,
disaggregation_mode: DisaggregationMode,
server_args: ServerArgs,
is_mla_backend: Optional[bool] = False,
):
self.kv_args = args
self.local_ip = get_local_ip_auto()
self.is_mla_backend = is_mla_backend
self.disaggregation_mode = disaggregation_mode
self.init_engine()
        # for prefill/decode multi-node inference
self.bootstrap_host = server_args.host
self.bootstrap_port = server_args.disaggregation_bootstrap_port
self.dist_init_addr = server_args.dist_init_addr
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
self.attn_dp_size = get_attention_dp_size()
self.attn_dp_rank = get_attention_dp_rank()
self.system_dp_size = (
1 if server_args.enable_dp_attention else server_args.dp_size
)
self.system_dp_rank = (
self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0
)
self.pp_size = server_args.pp_size
self.pp_rank = self.kv_args.pp_rank
self.request_status: Dict[int, KVPoll] = {}
self.rank_port = None
self.server_socket = zmq.Context().socket(zmq.PULL)
if is_valid_ipv6_address(self.local_ip):
self.server_socket.setsockopt(zmq.IPV6, 1)
self.register_buffer_to_engine()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
self.start_prefill_thread()
self._register_to_bootstrap()
self.session_failures = defaultdict(int)
self.failed_sessions = set()
self.session_lock = threading.Lock()
self.pp_group = get_pp_group()
# Determine the number of threads to use for kv sender
cpu_count = os.cpu_count()
transfer_thread_pool_size = get_int_env_var(
"SGLANG_DISAGGREGATION_THREAD_POOL_SIZE",
min(max(4, int(0.75 * cpu_count) // 8), 12),
)
transfer_queue_size = get_int_env_var("SGLANG_DISAGGREGATION_QUEUE_SIZE", 4)
self.transfer_queues: List[FastQueue] = [
FastQueue() for _ in range(transfer_queue_size)
]
assert transfer_thread_pool_size >= transfer_queue_size, (
f"The environment variable SGLANG_DISAGGREGATION_THREAD_POOL_SIZE={transfer_thread_pool_size} must be "
f"greater than or equal to SGLANG_DISAGGREGATION_QUEUE_SIZE={transfer_queue_size}."
)
self.executors = [
concurrent.futures.ThreadPoolExecutor(
transfer_thread_pool_size // transfer_queue_size
)
for _ in range(transfer_queue_size)
]
for queue, executor in zip(self.transfer_queues, self.executors):
threading.Thread(
target=self.transfer_worker, args=(queue, executor), daemon=True
).start()
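            # Sizing sketch (hypothetical values): with
            # SGLANG_DISAGGREGATION_THREAD_POOL_SIZE=8 and
            # SGLANG_DISAGGREGATION_QUEUE_SIZE=4, four FastQueues are created,
            # each drained by one worker thread backed by a 2-thread executor
            # (8 // 4 = 2).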
# If a timeout happens on the prefill side, it means prefill instances
# fail to receive the KV indices from the decode instance of this request.
# These timeout requests should be aborted to release the tree cache.
self.bootstrap_timeout = get_int_env_var(
"SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 300
)
self.enable_custom_mem_pool = get_bool_env_var(
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.heartbeat_failures = {}
self.session_pool = defaultdict(requests.Session)
self.session_pool_lock = threading.Lock()
self.addr_to_rooms_tracker = defaultdict(set)
self.connection_lock = threading.Lock()
self.required_prefill_response_num_table: Dict[int, int] = {}
self.prefill_response_tracker: Dict[int, Set[int]] = defaultdict(set)
# Heartbeat interval should be at least 2 seconds
self.heartbeat_interval = max(
float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
)
            # Max heartbeat failures should be at least 1
self.max_failures = max(
get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
)
self.start_decode_thread()
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
self.prefill_attn_tp_size_table: Dict[str, int] = {}
self.prefill_dp_size_table: Dict[str, int] = {}
self.prefill_pp_size_table: Dict[str, int] = {}
# If a timeout happens on the decode side, it means decode instances
# fail to receive the KV Cache transfer done signal after bootstrapping.
# These timeout requests should be aborted to release the tree cache.
self.waiting_timeout = get_int_env_var(
"SGLANG_DISAGGREGATION_WAITING_TIMEOUT", 300
)
else:
raise ValueError(
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
)
self.failure_records: Dict[int, str] = {}
self.failure_lock = threading.Lock()
def init_engine(self):
self.engine = MooncakeTransferEngine(
hostname=self.local_ip,
gpu_id=self.kv_args.gpu_id,
ib_device=self.kv_args.ib_device,
)
def register_buffer_to_engine(self):
# Batch register KV data buffers
if self.kv_args.kv_data_ptrs and self.kv_args.kv_data_lens:
self.engine.batch_register(
self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
)
# Batch register auxiliary data buffers
if self.kv_args.aux_data_ptrs and self.kv_args.aux_data_lens:
self.engine.batch_register(
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
)
@cache
def _connect(self, endpoint: str, is_ipv6: bool = False):
socket = zmq.Context().socket(zmq.PUSH)
if is_ipv6:
socket.setsockopt(zmq.IPV6, 1)
socket.connect(endpoint)
return socket
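    # NOTE: functools.cache keys on all arguments including `self`, so each
    # manager instance reuses one PUSH socket per (endpoint, is_ipv6) pair
    # instead of reconnecting for every status or aux message.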
def _transfer_data(self, mooncake_session_id, transfer_blocks):
if not transfer_blocks:
return 0
src_addrs, dst_addrs, lengths = zip(*transfer_blocks)
return self.engine.batch_transfer_sync(
mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
)
def send_kvcache(
self,
mooncake_session_id: str,
prefill_kv_indices: npt.NDArray[np.int32],
dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int32],
executor: concurrent.futures.ThreadPoolExecutor,
):
# Group by indices
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
prefill_kv_indices, dst_kv_indices
)
layers_params = None
# pp is not supported on the decode side yet
start_layer = self.kv_args.prefill_start_layer
end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
if self.is_mla_backend:
src_kv_ptrs = self.kv_args.kv_data_ptrs
layers_per_pp_stage = len(src_kv_ptrs)
dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
kv_item_len = self.kv_args.kv_item_lens[0]
layers_params = [
(
src_kv_ptrs[layer_id],
dst_kv_ptrs[layer_id],
kv_item_len,
)
for layer_id in range(layers_per_pp_stage)
]
else:
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
dst_num_total_layers = num_kv_layers * self.pp_size
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
layers_per_pp_stage = len(src_k_ptrs)
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
]
kv_item_len = self.kv_args.kv_item_lens[0]
layers_params = [
(
src_k_ptrs[layer_id],
dst_k_ptrs[layer_id],
kv_item_len,
)
for layer_id in range(layers_per_pp_stage)
] + [
(
src_v_ptrs[layer_id],
dst_v_ptrs[layer_id],
kv_item_len,
)
for layer_id in range(layers_per_pp_stage)
]
assert layers_params is not None
def set_transfer_blocks(
src_ptr: int, dst_ptr: int, item_len: int
) -> List[Tuple[int, int, int]]:
transfer_blocks = []
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
src_addr = src_ptr + int(prefill_index[0]) * item_len
dst_addr = dst_ptr + int(decode_index[0]) * item_len
length = item_len * len(prefill_index)
transfer_blocks.append((src_addr, dst_addr, length))
return transfer_blocks
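        # Worked example (hypothetical indices): if prefill_kv_blocks holds the
        # contiguous run [4, 5, 6] and dst_kv_blocks holds [9, 10, 11], a single
        # block (src_ptr + 4 * item_len, dst_ptr + 9 * item_len, 3 * item_len)
        # is produced, so the whole run moves in one transfer.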
# Worker function for processing a single layer
def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)
return self._transfer_data(mooncake_session_id, transfer_blocks)
# Worker function for processing all layers in a batch
def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:
transfer_blocks = []
for src_ptr, dst_ptr, item_len in layers_params:
transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))
return self._transfer_data(mooncake_session_id, transfer_blocks)
if self.enable_custom_mem_pool:
futures = [
executor.submit(
process_layer,
src_ptr,
dst_ptr,
item_len,
)
for (src_ptr, dst_ptr, item_len) in layers_params
]
for future in concurrent.futures.as_completed(futures):
status = future.result()
if status != 0:
for f in futures:
f.cancel()
return status
else:
# Combining all layers' params in one batch transfer is more efficient
# compared to using multiple threads
return process_layers(layers_params)
return 0
def send_kvcache_slice(
self,
mooncake_session_id: str,
prefill_kv_indices: npt.NDArray[np.int64],
dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int64],
dst_tp_rank: int,
dst_attn_tp_size: int,
dst_kv_item_len: int,
executor: concurrent.futures.ThreadPoolExecutor,
):
"""
Sends KV cache slices from this Prefill rank to a target Decode rank,
supporting generic M-to-N TP size configurations.
NOTE: This implementation calls the transfer engine for each token slot within
each page to ensure correctness for any page_size and head-slicing configuration.
This may introduce performance overhead (increased TTFT) for long sequences.
"""
# Extract configuration
local_tp_rank_in_group = self.kv_args.engine_rank % self.attn_tp_size
src_kv_item_len = self.kv_args.kv_item_lens[0]
dst_tp_rank_in_group = dst_tp_rank % dst_attn_tp_size
num_kv_heads = self.kv_args.kv_head_num
num_layers = len(self.kv_args.kv_data_ptrs)
page_size = self.kv_args.page_size
# Calculate head distribution
src_heads_per_rank = num_kv_heads
dst_heads_per_rank = num_kv_heads * self.attn_tp_size // dst_attn_tp_size
bytes_per_head_slice_to_send = (
dst_kv_item_len // page_size // dst_heads_per_rank
)
# Determine slicing parameters based on TP configuration
if self.attn_tp_size > dst_attn_tp_size:
# Send KVCache from multiple prefill instances to 1 decode instance
src_head_start_offset = 0
num_heads_to_send = src_heads_per_rank
dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
else:
# Send KVCache from 1 prefill instance to multiple decode instances
src_head_start_offset = dst_tp_rank_in_group * dst_heads_per_rank
num_heads_to_send = dst_heads_per_rank
dst_head_start_offset = 0
# pp is not supported on the decode side yet
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
dst_num_total_layers = num_kv_layers * self.pp_size
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
layers_per_pp_stage = len(src_k_ptrs)
start_layer = self.pp_rank * layers_per_pp_stage
end_layer = start_layer + layers_per_pp_stage
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
]
# Calculate precise byte offset and length for the sub-slice within the token
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
# Sanity check: The data sub-slice to be sent should fit into the dst buffer.
# This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size)
if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size):
logger.error(
f"[{mooncake_session_id}] slice size ({heads_bytes_per_token_to_send}) exceeds "
f"target token slot size ({dst_kv_item_len // page_size})"
)
return -1
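        # Worked example (hypothetical sizes, page_size=1): a prefill rank with
        # attn_tp_size=1 and num_kv_heads=8 sending to dst_attn_tp_size=2 gives
        # dst_heads_per_rank = 8 * 1 // 2 = 4, so decode rank 0 receives heads
        # [0:4) (src_head_start_offset = 0) and decode rank 1 receives heads
        # [4:8) (src_head_start_offset = 4 * bytes_per_head_slice_to_send). In
        # the opposite direction (prefill TP > decode TP), each prefill rank
        # sends its full head range at the matching dst_head_start_offset.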
layers_params = [
(
src_k_ptrs[layer_id],
dst_k_ptrs[layer_id],
src_kv_item_len,
dst_kv_item_len,
src_head_slice_offset,
dst_head_slice_offset,
heads_bytes_per_token_to_send,
)
for layer_id in range(layers_per_pp_stage)
] + [
(
src_v_ptrs[layer_id],
dst_v_ptrs[layer_id],
src_kv_item_len,
dst_kv_item_len,
src_head_slice_offset,
dst_head_slice_offset,
heads_bytes_per_token_to_send,
)
for layer_id in range(layers_per_pp_stage)
]
def process_layer_tp_aware(layer_params):
(
src_ptr,
dst_ptr,
src_item_len,
dst_item_len,
src_head_slice_offset,
dst_head_slice_offset,
heads_bytes_per_token_to_send,
) = layer_params
src_addr_list = []
dst_addr_list = []
length_list = []
# Calculate strides for a single token slot
bytes_per_token_on_prefill = src_item_len // page_size
bytes_per_token_on_decode = dst_item_len // page_size
for i in range(len(prefill_kv_indices)):
prefill_page_idx = int(prefill_kv_indices[i])
decode_page_idx = int(dst_kv_indices[i])
# Get the starting addresses for the current src and dst pages
src_page_start_addr = src_ptr + prefill_page_idx * src_item_len
dst_page_start_addr = dst_ptr + decode_page_idx * dst_item_len
# Iterate through each valid token slot within the current page
for token_slot_in_page in range(page_size):
# Calculate the start address of the current token slot
src_token_slot_start_addr = (
src_page_start_addr
+ token_slot_in_page * bytes_per_token_on_prefill
)
dst_token_slot_start_addr = (
dst_page_start_addr
+ token_slot_in_page * bytes_per_token_on_decode
)
# Calculate final src and dst addresses by applying head-slice offsets
src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset
src_addr_list.append(src_slice_addr)
dst_addr_list.append(dst_slice_addr)
length_list.append(heads_bytes_per_token_to_send)
return self.engine.batch_transfer_sync(
mooncake_session_id, src_addr_list, dst_addr_list, length_list
)
futures = [
executor.submit(
process_layer_tp_aware,
layer_params,
)
for layer_params in layers_params
]
for future in concurrent.futures.as_completed(futures):
status = future.result()
if status != 0:
for f in futures:
f.cancel()
return status
return 0
def send_aux(
self,
req: TransferInfo,
prefill_aux_index: int,
dst_aux_ptrs: list[int],
):
# TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free
if self.enable_custom_mem_pool:
return self.send_aux_tcp(req, prefill_aux_index, dst_aux_ptrs)
transfer_blocks = []
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
prefill_aux_item_lens = self.kv_args.aux_item_lens
for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
length = prefill_aux_item_lens[i]
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
dst_addr = dst_aux_ptrs[i] + length * req.dst_aux_index
transfer_blocks.append((src_addr, dst_addr, length))
return self._transfer_data(req.mooncake_session_id, transfer_blocks)
def send_aux_tcp(
self,
req: TransferInfo,
prefill_aux_index: int,
dst_aux_ptrs: list[int],
):
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
prefill_aux_item_lens = self.kv_args.aux_item_lens
for i in range(len(prefill_aux_ptrs)):
length = prefill_aux_item_lens[i]
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
data = AuxDataCodec.serialize_data_from_buffer(src_addr, length)
self.send_aux_data_to_endpoint(
remote=req.endpoint,
dst_port=req.dst_port,
room=req.room,
buffer_index=i,
aux_index=req.dst_aux_index,
data=data,
)
return 0
def send_aux_data_to_endpoint(
self,
remote: str,
dst_port: int,
room: int,
buffer_index: int,
aux_index: int,
data: bytes,
):
socket = self._connect(
format_tcp_address(remote, dst_port), is_ipv6=is_valid_ipv6_address(remote)
)
socket.send_multipart(
[
MooncakeKVManager.AUX_DATA_HEADER,
str(room).encode("ascii"),
str(buffer_index).encode("ascii"),
str(aux_index).encode("ascii"),
struct.pack(">I", len(data)),
data,
]
)
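    # The AUX_DATA wire format mirrors _handle_aux_data on the decode side:
    #   frame 0: AUX_DATA_HEADER   frame 1: room (ASCII int)
    #   frame 2: buffer_index      frame 3: aux_index
    #   frame 4: payload length as big-endian uint32 (struct ">I")
    #   frame 5: raw payload bytes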
def sync_status_to_decode_endpoint(
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
):
self._connect(
format_tcp_address(remote, dst_port), is_ipv6=is_valid_ipv6_address(remote)
).send_multipart(
[
str(room).encode("ascii"),
str(status).encode("ascii"),
str(prefill_rank).encode("ascii"),
]
)
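    # Status updates are 3-frame messages parsed by decode_thread:
    # (room, KVPoll status value, prefill rank), each an ASCII-encoded integer.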
def transfer_worker(
self, queue: FastQueue, executor: concurrent.futures.ThreadPoolExecutor
):
while True:
try:
kv_chunk: TransferKVChunk = queue.get()
reqs_to_be_processed = (
self.transfer_infos[kv_chunk.room].values()
if kv_chunk.room in self.transfer_infos
else []
)
polls = []
dst_ranks_infos = []
local_rank = self.attn_tp_rank * self.pp_size + self.pp_rank
for req in reqs_to_be_processed:
if not req.is_dummy:
# Early exit if the request has failed
with self.session_lock:
if req.mooncake_session_id in self.failed_sessions:
self.record_failure(
kv_chunk.room,
f"Decode instance could be dead, remote mooncake session {req.mooncake_session_id} is not alive",
)
self.update_status(kv_chunk.room, KVPoll.Failed)
self.sync_status_to_decode_endpoint(
req.endpoint,
req.dst_port,
req.room,
KVPoll.Failed,
local_rank,
)
break
chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
                        # NOTE: This is a temporary workaround for the case where prefill_kv_indices
                        # mismatches dst_kv_indices when page size > 1; this should never happen.
if len(chunked_dst_kv_indice) < len(
kv_chunk.prefill_kv_indices
):
logger.warning(
f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
)
kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[
: len(chunked_dst_kv_indice)
]
target_rank_registration_info: KVArgsRegisterInfo = (
self.decode_kv_args_table[req.mooncake_session_id]
)
if self.is_mla_backend or (
self.attn_tp_size
== target_rank_registration_info.dst_attn_tp_size
):
ret = self.send_kvcache(
req.mooncake_session_id,
kv_chunk.prefill_kv_indices,
target_rank_registration_info.dst_kv_ptrs,
chunked_dst_kv_indice,
executor,
)
else:
ret = self.send_kvcache_slice(
req.mooncake_session_id,
kv_chunk.prefill_kv_indices,
target_rank_registration_info.dst_kv_ptrs,
chunked_dst_kv_indice,
target_rank_registration_info.dst_tp_rank,
target_rank_registration_info.dst_attn_tp_size,
target_rank_registration_info.dst_kv_item_len,
executor,
)
if ret != 0:
with self.session_lock:
self.session_failures[req.mooncake_session_id] += 1
                                # Failures should never happen if the session is alive; if the session fails once, mark it as failed
if self.session_failures[req.mooncake_session_id] >= 1:
self.failed_sessions.add(req.mooncake_session_id)
logger.error(
f"Session {req.mooncake_session_id} failed."
)
self.record_failure(
kv_chunk.room,
f"Failed to send kv chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}",
)
self.update_status(kv_chunk.room, KVPoll.Failed)
self.sync_status_to_decode_endpoint(
req.endpoint,
req.dst_port,
req.room,
KVPoll.Failed,
local_rank,
)
break
if kv_chunk.is_last:
if self.pp_group.is_last_rank:
                                # The aux data only needs to be sent with the last chunk
ret = self.send_aux(
req,
kv_chunk.prefill_aux_index,
target_rank_registration_info.dst_aux_ptrs,
)
polls.append(True if ret == 0 else False)
dst_ranks_infos.append(
(req.endpoint, req.dst_port, req.room)
)
# Only sync status when all the dst ranks have received the kvcache
if len(polls) == req.required_dst_info_num:
status = KVPoll.Success if all(polls) else KVPoll.Failed
self.update_status(req.room, status)
for endpoint, dst_port, room in dst_ranks_infos:
self.sync_status_to_decode_endpoint(
endpoint, dst_port, room, status, local_rank
)
else:
                    # A dummy request means this decode rank is not used, so its status can be marked as success directly.
                    # Dummy requests do not need to sync status to the decode endpoint.
if kv_chunk.is_last and req.room in self.request_status:
self.update_status(req.room, KVPoll.Success)
if (
kv_chunk.room not in self.request_status
or self.check_status(kv_chunk.room) == KVPoll.Success
):
if kv_chunk.room in self.transfer_infos:
self.transfer_infos.pop(kv_chunk.room)
except Exception as e:
# NOTE(shangming): Remove this when we make sure the transfer thread is bug-free
raise RuntimeError(
f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
)
def _bind_server_socket(self):
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
def start_prefill_thread(self):
self.rank_port = get_free_port()
self._bind_server_socket()
def bootstrap_thread():
"""This thread recvs pre-alloc notification from the decode engine"""
# KVPoll.Bootstrapping -> KVPoll.WaitingForInput
while True:
waiting_req_bytes = self.server_socket.recv_multipart()
room = waiting_req_bytes[0].decode("ascii")
mooncake_session_id = waiting_req_bytes[3].decode("ascii")
if room == "None":
self.decode_kv_args_table[mooncake_session_id] = (
KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
)
with self.session_lock:
if mooncake_session_id in self.failed_sessions:
self.failed_sessions.remove(mooncake_session_id)
if mooncake_session_id in self.session_failures:
del self.session_failures[mooncake_session_id]
logger.debug(
f"Register KVArgs from {mooncake_session_id} successfully"
)
continue
else:
required_dst_info_num = int(waiting_req_bytes[6].decode("ascii"))
room = int(room)
if room not in self.transfer_infos:
self.transfer_infos[room] = {}
self.transfer_infos[room][mooncake_session_id] = (
TransferInfo.from_zmq(waiting_req_bytes)
)
# NOTE: after bootstrapping we can mark the req as waiting for input
if len(self.transfer_infos[room]) == required_dst_info_num:
self.update_status(room, KVPoll.WaitingForInput)
threading.Thread(target=bootstrap_thread).start()
def _handle_aux_data(self, msg: List[bytes]):
"""Handle AUX_DATA messages received by the decode thread."""
room = int(msg[1].decode("ascii"))
buffer_index = int(msg[2].decode("ascii"))
aux_index = int(msg[3].decode("ascii"))
data_length = struct.unpack(">I", msg[4])[0]
data = msg[5]
if len(data) != data_length:
logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}")
return
AuxDataCodec.deserialize_data_to_buffer(
self.kv_args, buffer_index, aux_index, data
)
logger.debug(
f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
)
def start_decode_thread(self):
self.rank_port = get_free_port()
self._bind_server_socket()
def decode_thread():
while True:
msg = self.server_socket.recv_multipart()
if msg[0] == MooncakeKVManager.AUX_DATA_HEADER:
self._handle_aux_data(msg)
continue
(bootstrap_room, status, prefill_rank) = msg
status = int(status.decode("ascii"))
bootstrap_room = int(bootstrap_room.decode("ascii"))
prefill_rank = int(prefill_rank.decode("ascii"))
if status == KVPoll.Success:
if bootstrap_room in self.request_status:
self.prefill_response_tracker[bootstrap_room].add(prefill_rank)
expected_response_num = (
self.required_prefill_response_num_table[bootstrap_room]
)
arrived_response_num = len(
self.prefill_response_tracker[bootstrap_room]
)
if arrived_response_num == expected_response_num:
self.update_status(bootstrap_room, KVPoll.Success)
elif status == KVPoll.Failed:
self.record_failure(
bootstrap_room,
f"Failed to get kvcache from prefill instance, it might be dead",
)
self.update_status(bootstrap_room, status)
def heartbeat_checker():
while True:
time.sleep(self.heartbeat_interval)
with self.connection_lock:
addresses = list(self.prefill_dp_size_table.keys())
for bootstrap_addr in addresses:
session = None
try:
with self.session_pool_lock:
session = self.session_pool[bootstrap_addr]
response = session.get(
f"http://{bootstrap_addr}/health",
timeout=(2, 3),
headers={"Connection": "keep-alive"},
)
if response.status_code == 200:
self.heartbeat_failures[bootstrap_addr] = 0
current_rooms = self.addr_to_rooms_tracker[
bootstrap_addr
].copy()
for bootstrap_room in current_rooms:
                                # Remove finished requests (no longer in request_status) from the tracker
if bootstrap_room not in self.request_status:
self.addr_to_rooms_tracker[bootstrap_addr].discard(
bootstrap_room
)
else:
logger.info(
f"Attempting to reconnect to {bootstrap_addr}..."
)
self.heartbeat_failures[bootstrap_addr] = (
self.heartbeat_failures.get(bootstrap_addr, 0) + 1
)
with self.session_pool_lock:
if bootstrap_addr in self.session_pool:
del self.session_pool[bootstrap_addr]
except Exception:
logger.info(f"Attempting to reconnect to {bootstrap_addr}...")
self.heartbeat_failures[bootstrap_addr] = (
self.heartbeat_failures.get(bootstrap_addr, 0) + 1
)
if (
self.heartbeat_failures.get(bootstrap_addr, 0)
>= self.max_failures
):
self._handle_node_failure(bootstrap_addr)
with self.session_pool_lock:
if bootstrap_addr in self.session_pool:
del self.session_pool[bootstrap_addr]
threading.Thread(target=decode_thread).start()
threading.Thread(target=heartbeat_checker).start()
def add_transfer_request(
self,
bootstrap_room: int,
kv_indices: npt.NDArray[np.int32],
index_slice: slice,
is_last: bool,
aux_index: Optional[int] = None,
):
assert self.disaggregation_mode == DisaggregationMode.PREFILL
assert not is_last or (is_last and aux_index is not None)
if (
bootstrap_room not in self.request_status
or self.check_status(bootstrap_room) == KVPoll.Failed
):
logger.debug(
"Request with bootstrap_room=%s already failed", bootstrap_room
)
return
if bootstrap_room not in self.transfer_infos:
# This means that the current rank is a dummy rank for this request,
# and it has already been marked as success, so there is no need to
# add further chunks into the transfer queue.
return
        # NOTE(shangming): shard by dst_infos so that requests with the same
        # dst sessions land in the same queue, which enables early abort for
        # failed sessions.
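        # Worked example (hypothetical ports): two dst sessions on ports 18000
        # and 18001 with 4 transfer queues give
        # shard_idx = (18000 + 18001) % 4 = 1, so every chunk of this room
        # lands in the same queue.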
dst_infos = self.transfer_infos[bootstrap_room].keys()
session_port_sum = sum(int(session.rsplit(":", 1)[1]) for session in dst_infos)
shard_idx = session_port_sum % len(self.transfer_queues)
self.transfer_queues[shard_idx].put(
TransferKVChunk(
room=bootstrap_room,
prefill_kv_indices=kv_indices,
index_slice=index_slice,
is_last=is_last,
prefill_aux_index=aux_index,
)
)
def check_status(self, bootstrap_room: int):
return self.request_status[bootstrap_room]
def update_status(self, bootstrap_room: int, status: KVPoll):
if bootstrap_room not in self.request_status:
self.request_status[bootstrap_room] = status
else:
            # NOTE: the status may only move forward, except that KVPoll.Failed overrides any state
if status == KVPoll.Failed:
self.request_status[bootstrap_room] = KVPoll.Failed
else:
self.request_status[bootstrap_room] = max(
self.request_status[bootstrap_room], status
)
def record_failure(self, bootstrap_room: int, failure_reason: str):
with self.failure_lock:
self.failure_records[bootstrap_room] = failure_reason
def get_session_id(self):
return self.engine.get_session_id()
def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr:
# multi node case: bootstrap server's host is dist_init_addr
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
if self.dist_init_addr.endswith("]"):
host = self.dist_init_addr
else:
host, _ = self.dist_init_addr.rsplit(":", 1)
else:
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
else:
            # single node case: bootstrap server's host is the same as the http server's host
host = self.bootstrap_host
host = maybe_wrap_ipv6_address(host)
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
url = f"http://{bootstrap_server_url}/route"
payload = {
"role": "Prefill",
"attn_tp_size": self.attn_tp_size,
"attn_tp_rank": self.attn_tp_rank,
"attn_dp_size": self.attn_dp_size,
"attn_dp_rank": self.attn_dp_rank,
"pp_size": self.pp_size,
"pp_rank": self.pp_rank,
"system_dp_size": self.system_dp_size,
"system_dp_rank": self.system_dp_rank,
"rank_ip": self.local_ip,
"rank_port": self.rank_port,
}
try:
response = requests.put(url, json=payload, timeout=5)
if response.status_code == 200:
logger.debug("Prefill successfully registered to bootstrap server.")
else:
logger.error(
f"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}"
)
except Exception as e:
logger.error(
f"Prefill instance failed to register to bootstrap server: {e}"
)
def _handle_node_failure(self, failed_bootstrap_addr):
with self.connection_lock:
keys_to_remove = [
k for k in self.connection_pool if k.startswith(failed_bootstrap_addr)
]
for k in keys_to_remove:
del self.connection_pool[k]
if failed_bootstrap_addr in self.prefill_attn_tp_size_table:
del self.prefill_attn_tp_size_table[failed_bootstrap_addr]
if failed_bootstrap_addr in self.prefill_dp_size_table:
del self.prefill_dp_size_table[failed_bootstrap_addr]
if failed_bootstrap_addr in self.prefill_pp_size_table:
del self.prefill_pp_size_table[failed_bootstrap_addr]
possible_affected_rooms = self.addr_to_rooms_tracker.get(
failed_bootstrap_addr, []
)
if failed_bootstrap_addr in self.addr_to_rooms_tracker:
del self.addr_to_rooms_tracker[failed_bootstrap_addr]
# Report the requests associated with the failed bootstrap addr and mark their status as KVPoll.Failed
affected_rooms = []
for room in possible_affected_rooms:
if (
room in self.request_status
and self.check_status(room) != KVPoll.Success
):
self.record_failure(
room,
f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr})",
)
self.update_status(room, KVPoll.Failed)
affected_rooms.append(room)
logger.error(
f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), {len(affected_rooms)} requests affected"
)
class MooncakeKVSender(BaseKVSender):
def __init__(
self,
mgr: MooncakeKVManager,
bootstrap_addr: str,
bootstrap_room: int,
dest_tp_ranks: List[int],
pp_rank: int,
):
self.kv_mgr = mgr
self.bootstrap_room = bootstrap_room
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
self.aux_index = None
self.bootstrap_server_url = bootstrap_addr
self.conclude_state = None
self.init_time = time.time()
# inner state
self.curr_idx = 0
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
self.num_kv_indices = num_kv_indices
self.aux_index = aux_index
def send(
self,
kv_indices: npt.NDArray[np.int32],
):
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
self.curr_idx += len(kv_indices)
is_last = self.curr_idx == self.num_kv_indices
if not is_last:
self.kv_mgr.add_transfer_request(
self.bootstrap_room,
kv_indices,
index_slice,
False,
)
else:
self.kv_mgr.add_transfer_request(
self.bootstrap_room,
kv_indices,
index_slice,
True,
aux_index=self.aux_index,
)
def poll(self) -> KVPoll:
if self.conclude_state is None:
status = self.kv_mgr.check_status(self.bootstrap_room)
if status in (KVPoll.Success, KVPoll.Failed):
self.conclude_state = status
elif status == KVPoll.Bootstrapping:
if self.init_time is not None:
now = time.time()
elapsed = now - self.init_time
if elapsed >= self.kv_mgr.bootstrap_timeout:
logger.warning_once(
"Some requests timed out when bootstrapping, "
"which means prefill instances fail to receive the KV indices from the decode instance of this request. "
"If a greater mean TTFT is acceptable, you can 'export SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
)
self.kv_mgr.record_failure(
self.bootstrap_room,
f"Request {self.bootstrap_room} timed out after {elapsed:.1f}s in KVPoll.Bootstrapping",
)
self.conclude_state = KVPoll.Failed
return KVPoll.Failed
return status
else:
return self.conclude_state
def clear(self) -> None:
if self.bootstrap_room in self.kv_mgr.request_status:
self.kv_mgr.request_status.pop(self.bootstrap_room)
def failure_exception(self):
# Explicitly set the status to failure since this request has failed in another rank
if self.conclude_state is None:
self.conclude_state = KVPoll.Failed
self.clear()
with self.kv_mgr.failure_lock:
failure_reason = self.kv_mgr.failure_records.pop(
self.bootstrap_room, "Failed due to an unknown reason from another rank"
)
raise KVTransferError(self.bootstrap_room, failure_reason)
def abort(self):
self.kv_mgr.record_failure(
self.bootstrap_room,
"Aborted by AbortReq.",
)
# Explicitly set the status to failure since this request has been aborted
self.conclude_state = KVPoll.Failed
class MooncakeKVReceiver(BaseKVReceiver):
_ctx = zmq.Context()
_socket_cache = {}
_socket_locks = {}
_global_lock = threading.Lock()
def __init__(
self,
mgr: MooncakeKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
):
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
self.kv_mgr = mgr
self.session_id = self.kv_mgr.get_session_id()
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
self.conclude_state = None
self.init_time = None
self.data_parallel_rank = data_parallel_rank
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
(
self.prefill_attn_tp_size,
self.prefill_dp_size,
self.prefill_pp_size,
) = self._get_prefill_parallel_info_from_server()
if (
self.prefill_attn_tp_size is None
or self.prefill_dp_size is None
or self.prefill_pp_size is None
):
self.kv_mgr.record_failure(
self.bootstrap_room,
f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
)
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
return
else:
logger.debug(
f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}"
)
self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = (
self.prefill_attn_tp_size
)
self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
self.prefill_dp_size
)
self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = (
self.prefill_pp_size
)
else:
self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
self.bootstrap_addr
]
self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
self.bootstrap_addr
]
self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[
self.bootstrap_addr
]
# Currently, we don't allow prefill instance and decode instance to
# have different TP sizes per DP rank, except for models using MLA.
if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
self.target_tp_rank = (
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
)
self.required_dst_info_num = 1
self.required_prefill_response_num = 1 * (
self.prefill_pp_size // self.kv_mgr.pp_size
)
self.target_tp_ranks = [self.target_tp_rank]
elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
if not self.kv_mgr.is_mla_backend:
logger.warning_once(
"Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
)
self.target_tp_rank = (
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
self.required_dst_info_num = (
self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
)
self.required_prefill_response_num = 1 * (
self.prefill_pp_size // self.kv_mgr.pp_size
)
self.target_tp_ranks = [self.target_tp_rank]
else:
if not self.kv_mgr.is_mla_backend:
logger.warning_once(
"Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
)
            # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks;
self.target_tp_ranks = [
rank
for rank in range(
(self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
* (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
(self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
* (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
)
]
# For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
# multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
# or the KVPoll will never be set correctly
self.target_tp_rank = self.target_tp_ranks[0]
self.required_dst_info_num = 1
if self.kv_mgr.is_mla_backend:
self.required_prefill_response_num = (
self.prefill_pp_size // self.kv_mgr.pp_size
)
else:
self.required_prefill_response_num = (
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
) * (self.prefill_pp_size // self.kv_mgr.pp_size)
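        # Worked example (hypothetical sizes): prefill attn_tp_size=4, pp_size=2
        # and decode attn_tp_size=2, pp_size=1 for a non-MLA model give
        # required_prefill_response_num = (4 // 2) * (2 // 1) = 4: this decode
        # rank waits for done signals from 2 TP peers across 2 PP stages.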
if self.data_parallel_rank is not None:
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
self.target_dp_group = self.data_parallel_rank
else:
self.target_dp_group = bootstrap_room % self.prefill_dp_size
self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
self.required_prefill_response_num
)
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
bootstrap_key = (
f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
)
if bootstrap_key not in self.kv_mgr.connection_pool:
bootstrap_infos = []
for target_tp_rank in self.target_tp_ranks:
for target_pp_rank in range(self.prefill_pp_size):
bootstrap_info = self._get_bootstrap_info_from_server(
target_tp_rank, self.target_dp_group, target_pp_rank
)
if bootstrap_info is not None:
if self.kv_mgr.is_mla_backend:
# For MLA: target_tp_rank is the selected real rank, others are dummy ranks
bootstrap_info["is_dummy"] = not bool(
target_tp_rank == self.target_tp_rank
or self.target_tp_rank is None
)
else:
# For non-MLA: all target_tp_ranks are selected real ranks
bootstrap_info["is_dummy"] = False
logger.debug(
f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}"
)
bootstrap_infos.append(bootstrap_info)
else:
self.kv_mgr.record_failure(
self.bootstrap_room,
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}",
)
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
return
self.bootstrap_infos = bootstrap_infos
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
self._register_kv_args()
else:
self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
assert len(self.bootstrap_infos) > 0
self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)
def _get_bootstrap_info_from_server(
self, engine_rank, target_dp_group, target_pp_rank
):
"""Fetch the bootstrap info from the bootstrap server."""
try:
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}&target_pp_rank={target_pp_rank}"
response = requests.get(url, timeout=5)
if response.status_code == 200:
bootstrap_info = response.json()
return bootstrap_info
else:
logger.error(
f"Failed to get prefill server info: {response.status_code}, {response.text}"
)
return None
except Exception as e:
logger.error(f"Error fetching prefill info from bootstrap: {e}")
return None
def _get_prefill_parallel_info_from_server(
self,
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
"""Fetch the prefill parallel info from the bootstrap server."""
try:
url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
response = requests.get(url)
if response.status_code == 200:
prefill_parallel_info = response.json()
return (
int(prefill_parallel_info["prefill_attn_tp_size"]),
int(prefill_parallel_info["prefill_dp_size"]),
int(prefill_parallel_info["prefill_pp_size"]),
)
else:
logger.error(
f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
)
return None, None, None
except Exception as e:
logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
return None, None, None
def _register_kv_args(self):
for bootstrap_info in self.bootstrap_infos:
packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
)
packed_aux_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
)
# Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet
tp_rank = self.kv_mgr.kv_args.engine_rank
kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
dst_tp_rank = str(tp_rank).encode("ascii")
dst_attn_tp_size = str(self.kv_mgr.attn_tp_size).encode("ascii")
dst_kv_item_len = str(kv_item_len).encode("ascii")
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
with lock:
sock.send_multipart(
[
"None".encode("ascii"),
self.kv_mgr.local_ip.encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"),
self.session_id.encode("ascii"),
packed_kv_data_ptrs,
packed_aux_data_ptrs,
dst_tp_rank,
dst_attn_tp_size,
dst_kv_item_len,
]
)
@classmethod
def _connect(cls, endpoint: str, is_ipv6: bool = False):
with cls._global_lock:
if endpoint not in cls._socket_cache:
sock = cls._ctx.socket(zmq.PUSH)
if is_ipv6:
sock.setsockopt(zmq.IPV6, 1)
sock.connect(endpoint)
cls._socket_cache[endpoint] = sock
cls._socket_locks[endpoint] = threading.Lock()
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
@classmethod
def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
ip_address = bootstrap_info["rank_ip"]
port = bootstrap_info["rank_port"]
is_ipv6_address = is_valid_ipv6_address(ip_address)
sock, lock = cls._connect(
format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
)
return sock, lock
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
for bootstrap_info in self.bootstrap_infos:
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
is_dummy = bootstrap_info["is_dummy"]
with lock:
sock.send_multipart(
[
str(self.bootstrap_room).encode("ascii"),
self.kv_mgr.local_ip.encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"),
self.session_id.encode("ascii"),
kv_indices.tobytes() if not is_dummy else b"",
str(aux_index).encode("ascii") if not is_dummy else b"",
str(self.required_dst_info_num).encode("ascii"),
]
)
self.init_time = time.time()
def poll(self) -> KVPoll:
if self.conclude_state is None:
status = self.kv_mgr.check_status(self.bootstrap_room)
if status in (KVPoll.Success, KVPoll.Failed):
self.conclude_state = status
elif status == KVPoll.WaitingForInput:
if self.init_time is not None:
now = time.time()
elapsed = now - self.init_time
if elapsed >= self.kv_mgr.waiting_timeout:
logger.warning_once(
"Some requests fail to receive KV Cache transfer done signal after bootstrapping. "
"If a greater mean TTFT is acceptable, you can 'export SGLANG_DISAGGREGATION_WAITING_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
)
self.kv_mgr.record_failure(
self.bootstrap_room,
f"Request {self.bootstrap_room} timed out after {elapsed:.1f}s in KVPoll.WaitingForInput",
)
self.conclude_state = KVPoll.Failed
return KVPoll.Failed
return status
else:
return self.conclude_state
def clear(self) -> None:
if self.bootstrap_room in self.kv_mgr.request_status:
self.kv_mgr.request_status.pop(self.bootstrap_room)
if self.bootstrap_room in self.kv_mgr.required_prefill_response_num_table:
self.kv_mgr.required_prefill_response_num_table.pop(self.bootstrap_room)
if self.bootstrap_room in self.kv_mgr.prefill_response_tracker:
self.kv_mgr.prefill_response_tracker.pop(self.bootstrap_room)
def failure_exception(self):
# Explicitly set the status to failure since this request has failed in another rank
if self.conclude_state is None:
self.conclude_state = KVPoll.Failed
self.clear()
with self.kv_mgr.failure_lock:
failure_reason = self.kv_mgr.failure_records.pop(
self.bootstrap_room, "Failed due to an unknown reason from another rank"
)
raise KVTransferError(self.bootstrap_room, failure_reason)
def abort(self):
self.kv_mgr.record_failure(
self.bootstrap_room,
"Aborted by AbortReq.",
)
# Explicitly set the status to failure since this request has been aborted
self.conclude_state = KVPoll.Failed
class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
def __init__(self, host: str, port: int):
self.host = host
self.port = port
self.app = web.Application()
self.store = dict()
self.lock = asyncio.Lock()
self._setup_routes()
self.pp_size = None
self.attn_tp_size = None
self.dp_size = None
self.prefill_port_table: Dict[
int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]
] = {}
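        # prefill_port_table nesting:
        #   {dp_group: {attn_tp_rank: {pp_rank: {"rank_ip": ..., "rank_port": ...}}}}
        # It is filled by PUT /route and queried by GET /route.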
# Start bootstrap server
self.thread = threading.Thread(target=self._run_server, daemon=True)
self.run()
def run(self):
self.thread.start()
def _setup_routes(self):
self.app.router.add_route("*", "/route", self._handle_route)
self.app.router.add_get("/health", self._handle_health_check)
async def _handle_health_check(self, request):
return web.Response(text="OK", status=200)
async def _handle_route(self, request: web.Request):
method = request.method
if method == "PUT":
return await self._handle_route_put(request)
elif method == "GET":
return await self._handle_route_get(request)
else:
return web.Response(
text="Method not allowed", status=405, content_type="application/json"
)
async def _handle_route_put(self, request: web.Request):
data = await request.json()
role = data["role"]
attn_tp_size = data["attn_tp_size"]
attn_tp_rank = data["attn_tp_rank"]
attn_dp_size = data["attn_dp_size"]
attn_dp_rank = data["attn_dp_rank"]
pp_size = data["pp_size"]
pp_rank = data["pp_rank"]
system_dp_size = data["system_dp_size"]
system_dp_rank = data["system_dp_rank"]
rank_ip = data["rank_ip"]
rank_port = int(data["rank_port"])
if self.attn_tp_size is None:
self.attn_tp_size = attn_tp_size
if self.dp_size is None:
self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size
if self.pp_size is None:
self.pp_size = pp_size
if role == "Prefill":
if system_dp_size == 1:
dp_group = attn_dp_rank
else:
dp_group = system_dp_rank
            # Hold the lock to keep the registration thread-safe
async with self.lock:
if dp_group not in self.prefill_port_table:
self.prefill_port_table[dp_group] = {}
if attn_tp_rank not in self.prefill_port_table[dp_group]:
self.prefill_port_table[dp_group][attn_tp_rank] = {}
self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = {
"rank_ip": rank_ip,
"rank_port": rank_port,
}
logger.debug(
f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
)
return web.Response(text="OK", status=200)
async def _handle_route_get(self, request: web.Request):
engine_rank = request.query.get("engine_rank")
target_dp_group = request.query.get("target_dp_group")
target_pp_rank = request.query.get("target_pp_rank")
if not engine_rank or not target_dp_group or not target_pp_rank:
return web.Response(text="Missing inputs for bootstrap server.", status=400)
        # Currently, engine_rank == -1, target_dp_group == -1, and target_pp_rank == -1 are used to sync the parallel info
if (
int(engine_rank) == -1
and int(target_dp_group) == -1
and int(target_pp_rank) == -1
):
prefill_parallel_info = {
"prefill_attn_tp_size": self.attn_tp_size,
"prefill_dp_size": self.dp_size,
"prefill_pp_size": self.pp_size,
}
return web.json_response(prefill_parallel_info, status=200)
# Find corresponding prefill info
async with self.lock:
bootstrap_info = self.prefill_port_table[int(target_dp_group)][
int(engine_rank)
][int(target_pp_rank)]
if bootstrap_info is not None:
return web.json_response(bootstrap_info, status=200)
else:
return web.Response(text="Bootstrap info not Found", status=404)
def _run_server(self):
try:
# Event Loop
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
access_log = None
if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG:
access_log = self.app.logger
self._runner = web.AppRunner(self.app, access_log=access_log)
self._loop.run_until_complete(self._runner.setup())
site = web.TCPSite(self._runner, host=self.host, port=self.port)
self._loop.run_until_complete(site.start())
self._loop.run_forever()
except Exception as e:
logger.error(f"Server error: {str(e)}")
finally:
# Cleanup
self._loop.run_until_complete(self._runner.cleanup())
self._loop.close()
def close(self):
"""Shutdown"""
if self._loop is not None and self._loop.is_running():
self._loop.call_soon_threadsafe(self._loop.stop)
logger.info("Stopping server loop...")
if self.thread.is_alive():
self.thread.join(timeout=2)
logger.info("Server thread stopped")
def poll(self) -> KVPoll: ...
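# End-to-end flow (a hedged sketch; the launcher glue that constructs these
# objects lives outside this file):
#   1. The prefill side starts MooncakeKVBootstrapServer(host, bootstrap_port);
#      each prefill rank registers itself via HTTP PUT /route.
#   2. Each decode rank creates a MooncakeKVReceiver(mgr, bootstrap_addr, room),
#      which fetches the prefill parallel info and rank addresses via
#      HTTP GET /route and registers its KV/aux pointers over ZMQ.
#   3. receiver.init(kv_indices, aux_index) tells the chosen prefill ranks
#      where to write; MooncakeKVSender.send(kv_indices) streams the chunks and
#      the prefill side pushes a final status, after which poll() on both sides
#      returns KVPoll.Success or KVPoll.Failed.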