[PD Perf] Replace Queue with FastQueue (#6649)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Co-authored-by: Shangming Cai <caishangming@linux.alibaba.com>
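The core of this change: the old transfer thread polled a `queue.Queue` with a 10 ms timeout, paying for a timed wait plus a `queue.Empty` raise/catch on every idle cycle, while `FastQueue` sleeps on a `threading.Condition` until a producer calls `put()`. A minimal standalone sketch of the two consumer patterns (the `handle` function is a hypothetical stand-in for the real chunk-processing logic; `FastQueue` is the class this commit adds to utils.py):

```python
import queue
import threading
from collections import deque

def handle(item):
    # Hypothetical stand-in for the real transfer logic.
    pass

# Old pattern: wake up every 10 ms even when idle.
def polling_consumer(q: queue.Queue):
    while True:
        try:
            item = q.get(timeout=0.01)
        except queue.Empty:
            continue  # ~100 wakeups/s with nothing to do
        handle(item)

# New pattern: block on a condition variable until put() notifies.
class FastQueue:
    def __init__(self):
        self._buf = deque()
        self._cond = threading.Condition()

    def put(self, item):
        with self._cond:
            self._buf.append(item)
            self._cond.notify()

    def get(self):
        with self._cond:
            while not self._buf:
                self._cond.wait()
            return self._buf.popleft()

def blocking_consumer(q: FastQueue):
    while True:
        handle(q.get())  # no timeout, no exception handling, no idle CPU
```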
@@ -31,6 +31,7 @@ from sglang.srt.disaggregation.base.conn import (
 from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
 from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
+    FastQueue,
     group_concurrent_contiguous,
 )
 from sglang.srt.server_args import ServerArgs
@@ -151,7 +152,6 @@ class MooncakeKVManager(BaseKVManager):
         self.server_socket = zmq.Context().socket(zmq.PULL)
         self.register_buffer_to_engine()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.transfer_queue = queue.Queue()
             self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
             self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
             self.start_prefill_thread()
@@ -159,15 +159,31 @@ class MooncakeKVManager(BaseKVManager):
             self.session_failures = defaultdict(int)
             self.failed_sessions = set()
             self.session_lock = threading.Lock()

             # Determine the number of threads to use for kv sender
             cpu_count = os.cpu_count()
-            self.executor = concurrent.futures.ThreadPoolExecutor(
-                get_int_env_var(
-                    "SGLANG_DISAGGREGATION_THREAD_POOL_SIZE",
-                    min(max(1, cpu_count // 8), 8),
-                )
-            )
+            transfer_thread_pool_size = get_int_env_var(
+                "SGLANG_DISAGGREGATION_THREAD_POOL_SIZE",
+                min(max(4, int(0.75 * cpu_count) // 8), 12),
+            )
+            transfer_queue_size = get_int_env_var("SGLANG_DISAGGREGATION_QUEUE_SIZE", 4)
+            self.transfer_queues: List[FastQueue] = [
+                FastQueue() for _ in range(transfer_queue_size)
+            ]
+            assert transfer_thread_pool_size >= transfer_queue_size, (
+                f"The environment variable SGLANG_DISAGGREGATION_THREAD_POOL_SIZE={transfer_thread_pool_size} must be "
+                f"greater than or equal to SGLANG_DISAGGREGATION_QUEUE_SIZE={transfer_queue_size}."
+            )
+            self.executors = [
+                concurrent.futures.ThreadPoolExecutor(
+                    transfer_thread_pool_size // transfer_queue_size
+                )
+                for _ in range(transfer_queue_size)
+            ]
+            for queue, executor in zip(self.transfer_queues, self.executors):
+                threading.Thread(
+                    target=self.transfer_worker, args=(queue, executor), daemon=True
+                ).start()

             self.bootstrap_time_out = get_int_env_var(
                 "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 30
             )
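For reference, the new defaults scale with the host, and the assert guarantees every queue gets at least one thread. A quick sanity check of the sizing arithmetic above (hypothetical core counts; `os.cpu_count()` supplies the real value):

```python
def default_pool_size(cpu_count: int) -> int:
    # Mirrors the default in the hunk above: 0.75 of the cores,
    # divided by 8, clamped to the range [4, 12].
    return min(max(4, int(0.75 * cpu_count) // 8), 12)

for cores in (8, 32, 64, 128):
    pool = default_pool_size(cores)
    print(cores, pool, pool // 4)  # threads per queue with the default 4 queues
# 8   -> pool 4,  1 thread per queue
# 32  -> pool 4,  1 thread per queue
# 64  -> pool 6,  1 thread per queue (integer division leaves 2 threads unused)
# 128 -> pool 12, 3 threads per queue
```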
@@ -183,7 +199,7 @@ class MooncakeKVManager(BaseKVManager):
             )
             # Heartbeat failure should be at least 1
             self.max_failures = max(
-                int(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2)), 1
+                get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
             )
             self.start_decode_thread()
             self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
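The heartbeat change swaps a raw `os.getenv` plus `int()` for the same `get_int_env_var` helper used everywhere else in the file. A plausible shape for such a helper, as a sketch only (the real implementation lives in sglang's utils and may differ):

```python
import os

def get_int_env_var(name: str, default: int) -> int:
    # Assumed contract: return the variable as an int when it is set and
    # parseable, otherwise fall back to the default.
    value = os.getenv(name)
    if value is None or not value.strip():
        return default
    try:
        return int(value)
    except ValueError:
        return default
```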
@@ -220,6 +236,7 @@ class MooncakeKVManager(BaseKVManager):
         prefill_kv_indices: npt.NDArray[np.int64],
         dst_kv_ptrs: list[int],
         dst_kv_indices: npt.NDArray[np.int64],
+        executor: concurrent.futures.ThreadPoolExecutor,
     ):
         # Group by indices
         prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
@@ -251,7 +268,7 @@ class MooncakeKVManager(BaseKVManager):
                     return 0

         futures = [
-            self.executor.submit(
+            executor.submit(
                 process_layer,
                 src_ptr,
                 dst_ptr,
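With the executor now passed in per call, each chunk's per-layer transfers run on the pool that belongs to its queue shard instead of one manager-wide pool. A sketch of the fan-out-and-join pattern, assuming a `process_layer(src_ptr, dst_ptr, item_len)` that returns 0 on success (the name mirrors the diff; the stub body and abbreviated argument list are illustrative):

```python
import concurrent.futures

def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
    # Hypothetical stand-in for the Mooncake per-layer transfer call.
    return 0

def send_layers(executor: concurrent.futures.ThreadPoolExecutor, layers) -> int:
    # layers: one (src_ptr, dst_ptr, item_len) triple per KV layer.
    futures = [
        executor.submit(process_layer, src_ptr, dst_ptr, item_len)
        for (src_ptr, dst_ptr, item_len) in layers
    ]
    for future in concurrent.futures.as_completed(futures):
        if future.result() != 0:
            # One failed layer fails the whole chunk; best-effort cancel
            # of any layer task that has not started yet.
            for f in futures:
                f.cancel()
            return 1
    return 0
```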
@@ -298,6 +315,123 @@ class MooncakeKVManager(BaseKVManager):
             ]
         )

+    def transfer_worker(
+        self, queue: FastQueue, executor: concurrent.futures.ThreadPoolExecutor
+    ):
+        while True:
+            try:
+                kv_chunk: TransferKVChunk = queue.get()
+                reqs_to_be_processed = (
+                    self.transfer_infos[kv_chunk.room].values()
+                    if kv_chunk.room in self.transfer_infos
+                    else []
+                )
+                polls = []
+                dst_ranks_infos = []
+                for req in reqs_to_be_processed:
+                    if not req.is_dummy:
+                        # Early exit if the request has failed
+                        with self.session_lock:
+                            if req.mooncake_session_id in self.failed_sessions:
+                                self.record_failure(
+                                    kv_chunk.room,
+                                    f"Decode instance could be dead, remote mooncake session {req.mooncake_session_id} is not alive",
+                                )
+                                self.update_status(kv_chunk.room, KVPoll.Failed)
+                                self.sync_status_to_decode_endpoint(
+                                    req.endpoint,
+                                    req.dst_port,
+                                    req.room,
+                                    KVPoll.Failed,
+                                )
+                                break
+
+                        chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
+
+                        # NOTE: This is a temporary workaround for the case where the prefill_kv_indices
+                        # is mismatched with the dst_kv_indices when page size > 1; this should never happen.
+                        if len(chunked_dst_kv_indice) < len(
+                            kv_chunk.prefill_kv_indices
+                        ):
+                            kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[
+                                : len(chunked_dst_kv_indice)
+                            ]
+                            logger.warning(
+                                f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
+                            )
+
+                        ret = self.send_kvcache(
+                            req.mooncake_session_id,
+                            kv_chunk.prefill_kv_indices,
+                            self.decode_kv_args_table[
+                                req.mooncake_session_id
+                            ].dst_kv_ptrs,
+                            chunked_dst_kv_indice,
+                            executor,
+                        )
+                        if ret != 0:
+                            with self.session_lock:
+                                self.session_failures[req.mooncake_session_id] += 1
+                                # Failures should never happen if the session is not dead; if the session fails once, mark it as failed
+                                if self.session_failures[req.mooncake_session_id] >= 1:
+                                    self.failed_sessions.add(req.mooncake_session_id)
+                                    logger.error(
+                                        f"Session {req.mooncake_session_id} failed."
+                                    )
+                            self.record_failure(
+                                kv_chunk.room,
+                                f"Failed to send kv chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}",
+                            )
+                            self.update_status(kv_chunk.room, KVPoll.Failed)
+                            self.sync_status_to_decode_endpoint(
+                                req.endpoint, req.dst_port, req.room, KVPoll.Failed
+                            )
+                            break
+
+                        if kv_chunk.is_last:
+                            # Only the last chunk needs to send the aux data
+                            ret = self.send_aux(
+                                req.mooncake_session_id,
+                                kv_chunk.prefill_aux_index,
+                                self.decode_kv_args_table[
+                                    req.mooncake_session_id
+                                ].dst_aux_ptrs,
+                                req.dst_aux_index,
+                            )
+                            polls.append(True if ret == 0 else False)
+                            dst_ranks_infos.append(
+                                (req.endpoint, req.dst_port, req.room)
+                            )
+
+                            # Only sync status when all the dst ranks have received the kvcache
+                            if len(polls) == req.required_dst_info_num:
+                                status = KVPoll.Success if all(polls) else KVPoll.Failed
+                                self.update_status(req.room, status)
+                                for endpoint, dst_port, room in dst_ranks_infos:
+                                    self.sync_status_to_decode_endpoint(
+                                        endpoint, dst_port, room, status
+                                    )
+                    else:
+                        # A dummy request means the decode instance is not used, so its status can be marked as success directly
+                        # A dummy request does not need to sync status to the decode endpoint
+                        if kv_chunk.is_last and req.room in self.request_status:
+                            self.update_status(req.room, KVPoll.Success)
+
+                if (
+                    kv_chunk.room not in self.request_status
+                    or self.check_status(kv_chunk.room) == KVPoll.Success
+                ):
+                    if kv_chunk.room in self.transfer_infos:
+                        self.transfer_infos.pop(kv_chunk.room)
+
+            except queue.Empty:
+                continue
+            except Exception as e:
+                # NOTE(shangming): Remove this when we make sure the transfer thread is bug-free
+                raise RuntimeError(
+                    f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
+                )
+
     def start_prefill_thread(self):
         self.rank_port = get_free_port()
         self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
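Wiring it together: each `FastQueue` gets exactly one `transfer_worker` thread that blocks on `get()` and owns a private slice of the thread pool, so a stalled or dead destination only affects the shard its sessions hash to. A minimal sketch of this one-consumer-per-queue layout (it assumes the `FastQueue` from the utils.py hunk below is in scope; `handle_chunk` is a hypothetical stand-in for the send_kvcache/send_aux logic):

```python
import concurrent.futures
import threading

NUM_QUEUES = 4         # SGLANG_DISAGGREGATION_QUEUE_SIZE default
THREADS_PER_QUEUE = 3  # pool size // queue size on a large host

def handle_chunk(chunk, executor):
    # Hypothetical: the real worker processes the chunk inline and passes
    # `executor` down into send_kvcache for per-layer parallelism.
    pass

queues = [FastQueue() for _ in range(NUM_QUEUES)]
executors = [
    concurrent.futures.ThreadPoolExecutor(THREADS_PER_QUEUE)
    for _ in range(NUM_QUEUES)
]

def worker(q, executor):
    while True:
        chunk = q.get()  # blocks on the condition variable, no polling
        handle_chunk(chunk, executor)

for q, ex in zip(queues, executors):
    threading.Thread(target=worker, args=(q, ex), daemon=True).start()
```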
@@ -335,134 +469,7 @@ class MooncakeKVManager(BaseKVManager):
                 if len(self.transfer_infos[room]) == required_dst_info_num:
                     self.update_status(room, KVPoll.WaitingForInput)

-        def transfer_thread():
-            # TODO: Shall we use KVPoll.Transferring state?
-            while True:
-                try:
-                    kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01)
-                    reqs_to_be_processed = (
-                        self.transfer_infos[kv_chunk.room].values()
-                        if kv_chunk.room in self.transfer_infos
-                        else []
-                    )
-                    polls = []
-                    dst_ranks_infos = []
-                    for req in reqs_to_be_processed:
-                        if not req.is_dummy:
-                            # Early exit if the request has failed
-                            with self.session_lock:
-                                if req.mooncake_session_id in self.failed_sessions:
-                                    self.record_failure(
-                                        kv_chunk.room,
-                                        f"Decode instance could be dead, remote mooncake session {req.mooncake_session_id} is not alive",
-                                    )
-                                    self.update_status(kv_chunk.room, KVPoll.Failed)
-                                    self.sync_status_to_decode_endpoint(
-                                        req.endpoint,
-                                        req.dst_port,
-                                        req.room,
-                                        KVPoll.Failed,
-                                    )
-                                    break
-
-                            chunked_dst_kv_indice = req.dst_kv_indices[
-                                kv_chunk.index_slice
-                            ]
-
-                            # NOTE: This is a temporary workaround for the case where the prefill_kv_indices
-                            # is mismatched with the dst_kv_indices when page size > 1; this should never happen.
-                            if len(chunked_dst_kv_indice) < len(
-                                kv_chunk.prefill_kv_indices
-                            ):
-                                kv_chunk.prefill_kv_indices = (
-                                    kv_chunk.prefill_kv_indices[
-                                        : len(chunked_dst_kv_indice)
-                                    ]
-                                )
-                                logger.warning(
-                                    f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
-                                )
-
-                            ret = self.send_kvcache(
-                                req.mooncake_session_id,
-                                kv_chunk.prefill_kv_indices,
-                                self.decode_kv_args_table[
-                                    req.mooncake_session_id
-                                ].dst_kv_ptrs,
-                                chunked_dst_kv_indice,
-                            )
-                            if ret != 0:
-                                with self.session_lock:
-                                    self.session_failures[req.mooncake_session_id] += 1
-                                    # Failures should never happen if the session is not dead; if the session fails once, mark it as failed
-                                    if (
-                                        self.session_failures[req.mooncake_session_id]
-                                        >= 1
-                                    ):
-                                        self.failed_sessions.add(
-                                            req.mooncake_session_id
-                                        )
-                                        logger.error(
-                                            f"Session {req.mooncake_session_id} failed."
-                                        )
-                                self.record_failure(
-                                    kv_chunk.room,
-                                    f"Failed to send kv chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}",
-                                )
-                                self.update_status(kv_chunk.room, KVPoll.Failed)
-                                self.sync_status_to_decode_endpoint(
-                                    req.endpoint, req.dst_port, req.room, KVPoll.Failed
-                                )
-                                break
-
-                            if kv_chunk.is_last:
-                                # Only the last chunk needs to send the aux data
-                                ret = self.send_aux(
-                                    req.mooncake_session_id,
-                                    kv_chunk.prefill_aux_index,
-                                    self.decode_kv_args_table[
-                                        req.mooncake_session_id
-                                    ].dst_aux_ptrs,
-                                    req.dst_aux_index,
-                                )
-                                polls.append(True if ret == 0 else False)
-                                dst_ranks_infos.append(
-                                    (req.endpoint, req.dst_port, req.room)
-                                )
-
-                                # Only sync status when all the dst ranks have received the kvcache
-                                if len(polls) == req.required_dst_info_num:
-                                    status = (
-                                        KVPoll.Success if all(polls) else KVPoll.Failed
-                                    )
-                                    self.update_status(req.room, status)
-                                    for endpoint, dst_port, room in dst_ranks_infos:
-                                        self.sync_status_to_decode_endpoint(
-                                            endpoint, dst_port, room, status
-                                        )
-                        else:
-                            # A dummy request means the decode instance is not used, so its status can be marked as success directly
-                            # A dummy request does not need to sync status to the decode endpoint
-                            if kv_chunk.is_last and req.room in self.request_status:
-                                self.update_status(req.room, KVPoll.Success)
-
-                    if (
-                        kv_chunk.room not in self.request_status
-                        or self.check_status(kv_chunk.room) == KVPoll.Success
-                    ):
-                        if kv_chunk.room in self.transfer_infos:
-                            self.transfer_infos.pop(kv_chunk.room)
-
-                except queue.Empty:
-                    continue
-                except Exception as e:
-                    # NOTE(shangming): Remove this when we make sure the transfer thread is bug-free
-                    raise RuntimeError(
-                        f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
-                    )

         threading.Thread(target=bootstrap_thread).start()
-        threading.Thread(target=transfer_thread).start()

     def start_decode_thread(self):
         self.rank_port = get_free_port()
@@ -555,7 +562,14 @@ class MooncakeKVManager(BaseKVManager):
             )
             return

-        self.transfer_queue.put(
+        # NOTE(shangming): sharding according to the dst_infos to make sure
+        # requests with the same dst_sessions will be added into the same
+        # queue, which enables early abort with failed sessions.
+        dst_infos = self.transfer_infos[bootstrap_room].keys()
+        session_port_sum = sum(int(session.split(":")[1]) for session in dst_infos)
+        shard_idx = session_port_sum % len(self.transfer_queues)
+
+        self.transfer_queues[shard_idx].put(
             TransferKVChunk(
                 room=bootstrap_room,
                 prefill_kv_indices=kv_indices,
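The shard index depends only on the destination session ids, so every chunk of a request with the same set of decode sessions lands in the same queue, and a failed session can abort all of that request's queued chunks early. Worked through with hypothetical session ids and the default 4 queues, (17000 + 17002) % 4 = 2:

```python
def pick_shard(dst_infos, num_queues: int) -> int:
    # Same arithmetic as the hunk above: sum the port component of each
    # "ip:port" session id and take it modulo the number of queues.
    session_port_sum = sum(int(session.split(":")[1]) for session in dst_infos)
    return session_port_sum % num_queues

# Hypothetical mooncake session ids of the decode ranks for one request:
print(pick_shard(["10.0.0.1:17000", "10.0.0.2:17002"], 4))  # -> 2
```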
@@ -3,6 +3,7 @@ from __future__ import annotations
 import dataclasses
 import os
 import random
 import threading
 import warnings
+from collections import deque
 from enum import Enum
@@ -281,6 +282,25 @@ class MetadataBuffers:
         )


+class FastQueue:
+    def __init__(self):
+        self._buf = deque()
+        self._cond = threading.Condition()
+
+    def put(self, item):
+        with self._cond:
+            self._buf.append(item)
+            # wake up one thread blocked in get()
+            self._cond.notify()
+
+    def get(self):
+        with self._cond:
+            # if the queue is empty, block until put() notifies us
+            while not self._buf:
+                self._cond.wait()
+            return self._buf.popleft()
+
+
 def group_concurrent_contiguous(
     src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
 ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
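A quick usage sketch of `FastQueue` across two threads (assuming the class from the hunk above is in scope). The design choices worth noting: `put()` calls `notify()` rather than `notify_all()`, waking at most one blocked consumer per item, and `deque.append`/`popleft` are O(1), so the only synchronization cost is a single condition variable, whereas `queue.Queue` maintains three conditions plus unfinished-task bookkeeping on top of its lock:

```python
import threading

q = FastQueue()

def consumer():
    while True:
        item = q.get()    # blocks until put() notifies; no timeout, no queue.Empty
        if item is None:  # sentinel to shut the consumer down
            break
        print("got", item)

t = threading.Thread(target=consumer)
t.start()
for i in range(3):
    q.put(i)
q.put(None)
t.join()
```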