[PD Perf] replace Queue to FastQueue (#6649)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com> Co-authored-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -31,6 +31,7 @@ from sglang.srt.disaggregation.base.conn import (
|
|||||||
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
||||||
from sglang.srt.disaggregation.utils import (
|
from sglang.srt.disaggregation.utils import (
|
||||||
DisaggregationMode,
|
DisaggregationMode,
|
||||||
|
FastQueue,
|
||||||
group_concurrent_contiguous,
|
group_concurrent_contiguous,
|
||||||
)
|
)
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
@@ -151,7 +152,6 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
self.server_socket = zmq.Context().socket(zmq.PULL)
|
self.server_socket = zmq.Context().socket(zmq.PULL)
|
||||||
self.register_buffer_to_engine()
|
self.register_buffer_to_engine()
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
self.transfer_queue = queue.Queue()
|
|
||||||
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
|
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
|
||||||
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
|
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
|
||||||
self.start_prefill_thread()
|
self.start_prefill_thread()
|
||||||
@@ -159,15 +159,31 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
self.session_failures = defaultdict(int)
|
self.session_failures = defaultdict(int)
|
||||||
self.failed_sessions = set()
|
self.failed_sessions = set()
|
||||||
self.session_lock = threading.Lock()
|
self.session_lock = threading.Lock()
|
||||||
|
|
||||||
# Determine the number of threads to use for kv sender
|
# Determine the number of threads to use for kv sender
|
||||||
cpu_count = os.cpu_count()
|
cpu_count = os.cpu_count()
|
||||||
self.executor = concurrent.futures.ThreadPoolExecutor(
|
transfer_thread_pool_size = get_int_env_var(
|
||||||
get_int_env_var(
|
"SGLANG_DISAGGREGATION_THREAD_POOL_SIZE",
|
||||||
"SGLANG_DISAGGREGATION_THREAD_POOL_SIZE",
|
min(max(4, int(0.75 * cpu_count) // 8), 12),
|
||||||
min(max(1, cpu_count // 8), 8),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
transfer_queue_size = get_int_env_var("SGLANG_DISAGGREGATION_QUEUE_SIZE", 4)
|
||||||
|
self.transfer_queues: List[FastQueue] = [
|
||||||
|
FastQueue() for _ in range(transfer_queue_size)
|
||||||
|
]
|
||||||
|
assert transfer_thread_pool_size >= transfer_queue_size, (
|
||||||
|
f"The environment variable SGLANG_DISAGGREGATION_THREAD_POOL_SIZE={transfer_thread_pool_size} must be "
|
||||||
|
f"greater than or equal to SGLANG_DISAGGREGATION_QUEUE_SIZE={transfer_queue_size}."
|
||||||
|
)
|
||||||
|
self.executors = [
|
||||||
|
concurrent.futures.ThreadPoolExecutor(
|
||||||
|
transfer_thread_pool_size // transfer_queue_size
|
||||||
|
)
|
||||||
|
for _ in range(transfer_queue_size)
|
||||||
|
]
|
||||||
|
for queue, executor in zip(self.transfer_queues, self.executors):
|
||||||
|
threading.Thread(
|
||||||
|
target=self.transfer_worker, args=(queue, executor), daemon=True
|
||||||
|
).start()
|
||||||
|
|
||||||
self.bootstrap_time_out = get_int_env_var(
|
self.bootstrap_time_out = get_int_env_var(
|
||||||
"SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 30
|
"SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 30
|
||||||
)
|
)
|
||||||
@@ -183,7 +199,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
)
|
)
|
||||||
# Heartbeat failure should be at least 1
|
# Heartbeat failure should be at least 1
|
||||||
self.max_failures = max(
|
self.max_failures = max(
|
||||||
int(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2)), 1
|
get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
|
||||||
)
|
)
|
||||||
self.start_decode_thread()
|
self.start_decode_thread()
|
||||||
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
|
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
|
||||||
@@ -220,6 +236,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
prefill_kv_indices: npt.NDArray[np.int64],
|
prefill_kv_indices: npt.NDArray[np.int64],
|
||||||
dst_kv_ptrs: list[int],
|
dst_kv_ptrs: list[int],
|
||||||
dst_kv_indices: npt.NDArray[np.int64],
|
dst_kv_indices: npt.NDArray[np.int64],
|
||||||
|
executor: concurrent.futures.ThreadPoolExecutor,
|
||||||
):
|
):
|
||||||
# Group by indices
|
# Group by indices
|
||||||
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
|
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
|
||||||
@@ -251,7 +268,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
futures = [
|
futures = [
|
||||||
self.executor.submit(
|
executor.submit(
|
||||||
process_layer,
|
process_layer,
|
||||||
src_ptr,
|
src_ptr,
|
||||||
dst_ptr,
|
dst_ptr,
|
||||||
@@ -298,6 +315,123 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def transfer_worker(
|
||||||
|
self, queue: FastQueue, executor: concurrent.futures.ThreadPoolExecutor
|
||||||
|
):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
kv_chunk: TransferKVChunk = queue.get()
|
||||||
|
reqs_to_be_processed = (
|
||||||
|
self.transfer_infos[kv_chunk.room].values()
|
||||||
|
if kv_chunk.room in self.transfer_infos
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
polls = []
|
||||||
|
dst_ranks_infos = []
|
||||||
|
for req in reqs_to_be_processed:
|
||||||
|
if not req.is_dummy:
|
||||||
|
# Early exit if the request has failed
|
||||||
|
with self.session_lock:
|
||||||
|
if req.mooncake_session_id in self.failed_sessions:
|
||||||
|
self.record_failure(
|
||||||
|
kv_chunk.room,
|
||||||
|
f"Decode instance could be dead, remote mooncake session {req.mooncake_session_id} is not alive",
|
||||||
|
)
|
||||||
|
self.update_status(kv_chunk.room, KVPoll.Failed)
|
||||||
|
self.sync_status_to_decode_endpoint(
|
||||||
|
req.endpoint,
|
||||||
|
req.dst_port,
|
||||||
|
req.room,
|
||||||
|
KVPoll.Failed,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
|
||||||
|
|
||||||
|
# NOTE: This is temporarily a workaround to deal with the case where the prefill_kv_indices
|
||||||
|
# is mismatched with the dst_kv_indices when page size > 1, this should never happen.
|
||||||
|
if len(chunked_dst_kv_indice) < len(
|
||||||
|
kv_chunk.prefill_kv_indices
|
||||||
|
):
|
||||||
|
kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[
|
||||||
|
len(chunked_dst_kv_indice)
|
||||||
|
]
|
||||||
|
logger.warning(
|
||||||
|
f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
ret = self.send_kvcache(
|
||||||
|
req.mooncake_session_id,
|
||||||
|
kv_chunk.prefill_kv_indices,
|
||||||
|
self.decode_kv_args_table[
|
||||||
|
req.mooncake_session_id
|
||||||
|
].dst_kv_ptrs,
|
||||||
|
chunked_dst_kv_indice,
|
||||||
|
executor,
|
||||||
|
)
|
||||||
|
if ret != 0:
|
||||||
|
with self.session_lock:
|
||||||
|
self.session_failures[req.mooncake_session_id] += 1
|
||||||
|
# Failures should never happen if the session is not dead, if the session fails once, mark it as failed
|
||||||
|
if self.session_failures[req.mooncake_session_id] >= 1:
|
||||||
|
self.failed_sessions.add(req.mooncake_session_id)
|
||||||
|
logger.error(
|
||||||
|
f"Session {req.mooncake_session_id} failed."
|
||||||
|
)
|
||||||
|
self.record_failure(
|
||||||
|
kv_chunk.room,
|
||||||
|
f"Failed to send kv chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}",
|
||||||
|
)
|
||||||
|
self.update_status(kv_chunk.room, KVPoll.Failed)
|
||||||
|
self.sync_status_to_decode_endpoint(
|
||||||
|
req.endpoint, req.dst_port, req.room, KVPoll.Failed
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
if kv_chunk.is_last:
|
||||||
|
# Only the last chunk we need to send the aux data
|
||||||
|
ret = self.send_aux(
|
||||||
|
req.mooncake_session_id,
|
||||||
|
kv_chunk.prefill_aux_index,
|
||||||
|
self.decode_kv_args_table[
|
||||||
|
req.mooncake_session_id
|
||||||
|
].dst_aux_ptrs,
|
||||||
|
req.dst_aux_index,
|
||||||
|
)
|
||||||
|
polls.append(True if ret == 0 else False)
|
||||||
|
dst_ranks_infos.append(
|
||||||
|
(req.endpoint, req.dst_port, req.room)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only sync status when all the dst ranks have received the kvcache
|
||||||
|
if len(polls) == req.required_dst_info_num:
|
||||||
|
status = KVPoll.Success if all(polls) else KVPoll.Failed
|
||||||
|
self.update_status(req.room, status)
|
||||||
|
for endpoint, dst_port, room in dst_ranks_infos:
|
||||||
|
self.sync_status_to_decode_endpoint(
|
||||||
|
endpoint, dst_port, room, status
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Dummy request means the decode instance is not used, so its status can be marked as success directly
|
||||||
|
# Dummy request does not need to sync status to decode endpoint
|
||||||
|
if kv_chunk.is_last and req.room in self.request_status:
|
||||||
|
self.update_status(req.room, KVPoll.Success)
|
||||||
|
|
||||||
|
if (
|
||||||
|
kv_chunk.room not in self.request_status
|
||||||
|
or self.check_status(kv_chunk.room) == KVPoll.Success
|
||||||
|
):
|
||||||
|
if kv_chunk.room in self.transfer_infos:
|
||||||
|
self.transfer_infos.pop(kv_chunk.room)
|
||||||
|
|
||||||
|
except queue.Empty:
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
# NOTE(shangming): Remove this when we make sure the transfer thread is bug-free
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
|
||||||
|
)
|
||||||
|
|
||||||
def start_prefill_thread(self):
|
def start_prefill_thread(self):
|
||||||
self.rank_port = get_free_port()
|
self.rank_port = get_free_port()
|
||||||
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
|
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
|
||||||
@@ -335,134 +469,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
if len(self.transfer_infos[room]) == required_dst_info_num:
|
if len(self.transfer_infos[room]) == required_dst_info_num:
|
||||||
self.update_status(room, KVPoll.WaitingForInput)
|
self.update_status(room, KVPoll.WaitingForInput)
|
||||||
|
|
||||||
def transfer_thread():
|
|
||||||
# TODO: Shall we use KVPoll.Transferring state?
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01)
|
|
||||||
reqs_to_be_processed = (
|
|
||||||
self.transfer_infos[kv_chunk.room].values()
|
|
||||||
if kv_chunk.room in self.transfer_infos
|
|
||||||
else []
|
|
||||||
)
|
|
||||||
polls = []
|
|
||||||
dst_ranks_infos = []
|
|
||||||
for req in reqs_to_be_processed:
|
|
||||||
if not req.is_dummy:
|
|
||||||
# Early exit if the request has failed
|
|
||||||
with self.session_lock:
|
|
||||||
if req.mooncake_session_id in self.failed_sessions:
|
|
||||||
self.record_failure(
|
|
||||||
kv_chunk.room,
|
|
||||||
f"Decode instance could be dead, remote mooncake session {req.mooncake_session_id} is not alive",
|
|
||||||
)
|
|
||||||
self.update_status(kv_chunk.room, KVPoll.Failed)
|
|
||||||
self.sync_status_to_decode_endpoint(
|
|
||||||
req.endpoint,
|
|
||||||
req.dst_port,
|
|
||||||
req.room,
|
|
||||||
KVPoll.Failed,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
|
|
||||||
chunked_dst_kv_indice = req.dst_kv_indices[
|
|
||||||
kv_chunk.index_slice
|
|
||||||
]
|
|
||||||
|
|
||||||
# NOTE: This is temporarily a workaround to deal with the case where the prefill_kv_indices
|
|
||||||
# is mismatched with the dst_kv_indices when page size > 1, this should never happen.
|
|
||||||
if len(chunked_dst_kv_indice) < len(
|
|
||||||
kv_chunk.prefill_kv_indices
|
|
||||||
):
|
|
||||||
kv_chunk.prefill_kv_indices = (
|
|
||||||
kv_chunk.prefill_kv_indices[
|
|
||||||
len(chunked_dst_kv_indice)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
logger.warning(
|
|
||||||
f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
ret = self.send_kvcache(
|
|
||||||
req.mooncake_session_id,
|
|
||||||
kv_chunk.prefill_kv_indices,
|
|
||||||
self.decode_kv_args_table[
|
|
||||||
req.mooncake_session_id
|
|
||||||
].dst_kv_ptrs,
|
|
||||||
chunked_dst_kv_indice,
|
|
||||||
)
|
|
||||||
if ret != 0:
|
|
||||||
with self.session_lock:
|
|
||||||
self.session_failures[req.mooncake_session_id] += 1
|
|
||||||
# Failures should never happen if the session is not dead, if the session fails once, mark it as failed
|
|
||||||
if (
|
|
||||||
self.session_failures[req.mooncake_session_id]
|
|
||||||
>= 1
|
|
||||||
):
|
|
||||||
self.failed_sessions.add(
|
|
||||||
req.mooncake_session_id
|
|
||||||
)
|
|
||||||
logger.error(
|
|
||||||
f"Session {req.mooncake_session_id} failed."
|
|
||||||
)
|
|
||||||
self.record_failure(
|
|
||||||
kv_chunk.room,
|
|
||||||
f"Failed to send kv chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}",
|
|
||||||
)
|
|
||||||
self.update_status(kv_chunk.room, KVPoll.Failed)
|
|
||||||
self.sync_status_to_decode_endpoint(
|
|
||||||
req.endpoint, req.dst_port, req.room, KVPoll.Failed
|
|
||||||
)
|
|
||||||
break
|
|
||||||
|
|
||||||
if kv_chunk.is_last:
|
|
||||||
# Only the last chunk we need to send the aux data
|
|
||||||
ret = self.send_aux(
|
|
||||||
req.mooncake_session_id,
|
|
||||||
kv_chunk.prefill_aux_index,
|
|
||||||
self.decode_kv_args_table[
|
|
||||||
req.mooncake_session_id
|
|
||||||
].dst_aux_ptrs,
|
|
||||||
req.dst_aux_index,
|
|
||||||
)
|
|
||||||
polls.append(True if ret == 0 else False)
|
|
||||||
dst_ranks_infos.append(
|
|
||||||
(req.endpoint, req.dst_port, req.room)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only sync status when all the dst ranks have received the kvcache
|
|
||||||
if len(polls) == req.required_dst_info_num:
|
|
||||||
status = (
|
|
||||||
KVPoll.Success if all(polls) else KVPoll.Failed
|
|
||||||
)
|
|
||||||
self.update_status(req.room, status)
|
|
||||||
for endpoint, dst_port, room in dst_ranks_infos:
|
|
||||||
self.sync_status_to_decode_endpoint(
|
|
||||||
endpoint, dst_port, room, status
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Dummy request means the decode instance is not used, so its status can be marked as success directly
|
|
||||||
# Dummy request does not need to sync status to decode endpoint
|
|
||||||
if kv_chunk.is_last and req.room in self.request_status:
|
|
||||||
self.update_status(req.room, KVPoll.Success)
|
|
||||||
|
|
||||||
if (
|
|
||||||
kv_chunk.room not in self.request_status
|
|
||||||
or self.check_status(kv_chunk.room) == KVPoll.Success
|
|
||||||
):
|
|
||||||
if kv_chunk.room in self.transfer_infos:
|
|
||||||
self.transfer_infos.pop(kv_chunk.room)
|
|
||||||
|
|
||||||
except queue.Empty:
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
# NOTE(shangming): Remove this when we make sure the transfer thread is bug-free
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
|
|
||||||
)
|
|
||||||
|
|
||||||
threading.Thread(target=bootstrap_thread).start()
|
threading.Thread(target=bootstrap_thread).start()
|
||||||
threading.Thread(target=transfer_thread).start()
|
|
||||||
|
|
||||||
def start_decode_thread(self):
|
def start_decode_thread(self):
|
||||||
self.rank_port = get_free_port()
|
self.rank_port = get_free_port()
|
||||||
@@ -555,7 +562,14 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
self.transfer_queue.put(
|
# NOTE(shangming): sharding according to the dst_infos to make sure
|
||||||
|
# requests with the same dst_sessions will be added into the same
|
||||||
|
# queue, which enables early abort with failed sessions.
|
||||||
|
dst_infos = self.transfer_infos[bootstrap_room].keys()
|
||||||
|
session_port_sum = sum(int(session.split(":")[1]) for session in dst_infos)
|
||||||
|
shard_idx = session_port_sum % len(self.transfer_queues)
|
||||||
|
|
||||||
|
self.transfer_queues[shard_idx].put(
|
||||||
TransferKVChunk(
|
TransferKVChunk(
|
||||||
room=bootstrap_room,
|
room=bootstrap_room,
|
||||||
prefill_kv_indices=kv_indices,
|
prefill_kv_indices=kv_indices,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
import threading
|
||||||
import warnings
|
import warnings
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -281,6 +282,25 @@ class MetadataBuffers:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FastQueue:
|
||||||
|
def __init__(self):
|
||||||
|
self._buf = deque()
|
||||||
|
self._cond = threading.Condition()
|
||||||
|
|
||||||
|
def put(self, item):
|
||||||
|
with self._cond:
|
||||||
|
self._buf.append(item)
|
||||||
|
# wake up a thread of wait()
|
||||||
|
self._cond.notify()
|
||||||
|
|
||||||
|
def get(self):
|
||||||
|
with self._cond:
|
||||||
|
# if queue is empty ,block until is notified()
|
||||||
|
while not self._buf:
|
||||||
|
self._cond.wait()
|
||||||
|
return self._buf.popleft()
|
||||||
|
|
||||||
|
|
||||||
def group_concurrent_contiguous(
|
def group_concurrent_contiguous(
|
||||||
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
|
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
|
||||||
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
|
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
|
||||||
|
|||||||
Reference in New Issue
Block a user