[PD] Add support for different TP sizes per DP rank (#5922)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -37,6 +37,7 @@ class BaseKVManager(ABC):
|
|||||||
args: KVArgs,
|
args: KVArgs,
|
||||||
disaggregation_mode: DisaggregationMode,
|
disaggregation_mode: DisaggregationMode,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
|
is_mla_backend: Optional[bool] = False,
|
||||||
): ...
|
): ...
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
ReqToMetadataIdxAllocator,
|
ReqToMetadataIdxAllocator,
|
||||||
TransferBackend,
|
TransferBackend,
|
||||||
get_kv_class,
|
get_kv_class,
|
||||||
|
is_mla_backend,
|
||||||
kv_to_page_indices,
|
kv_to_page_indices,
|
||||||
poll_and_all_reduce,
|
poll_and_all_reduce,
|
||||||
)
|
)
|
||||||
@@ -87,6 +88,7 @@ class DecodePreallocQueue:
|
|||||||
self.req_to_token_pool = req_to_token_pool
|
self.req_to_token_pool = req_to_token_pool
|
||||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||||
self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
|
self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
|
||||||
|
self.is_mla_backend = is_mla_backend(self.token_to_kv_pool)
|
||||||
self.aux_dtype = aux_dtype
|
self.aux_dtype = aux_dtype
|
||||||
self.metadata_buffers = metadata_buffers
|
self.metadata_buffers = metadata_buffers
|
||||||
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
||||||
@@ -131,7 +133,10 @@ class DecodePreallocQueue:
|
|||||||
kv_args.gpu_id = self.scheduler.gpu_id
|
kv_args.gpu_id = self.scheduler.gpu_id
|
||||||
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
||||||
kv_manager = kv_manager_class(
|
kv_manager = kv_manager_class(
|
||||||
kv_args, DisaggregationMode.DECODE, self.scheduler.server_args
|
kv_args,
|
||||||
|
DisaggregationMode.DECODE,
|
||||||
|
self.scheduler.server_args,
|
||||||
|
self.is_mla_backend,
|
||||||
)
|
)
|
||||||
return kv_manager
|
return kv_manager
|
||||||
|
|
||||||
|
|||||||
@@ -68,16 +68,28 @@ class TransferInfo:
|
|||||||
mooncake_session_id: str
|
mooncake_session_id: str
|
||||||
dst_kv_indices: npt.NDArray[np.int64]
|
dst_kv_indices: npt.NDArray[np.int64]
|
||||||
dst_aux_index: int
|
dst_aux_index: int
|
||||||
|
required_dst_info_num: int
|
||||||
|
is_dummy: bool
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_zmq(cls, msg: List[bytes]):
|
def from_zmq(cls, msg: List[bytes]):
|
||||||
|
if msg[4] == b"" and msg[5] == b"":
|
||||||
|
is_dummy = True
|
||||||
|
dst_kv_indices = np.array([], dtype=np.int64)
|
||||||
|
dst_aux_index = None
|
||||||
|
else:
|
||||||
|
dst_kv_indices = np.frombuffer(msg[4], dtype=np.int64)
|
||||||
|
dst_aux_index = int(msg[5].decode("ascii"))
|
||||||
|
is_dummy = False
|
||||||
return cls(
|
return cls(
|
||||||
room=int(msg[0].decode("ascii")),
|
room=int(msg[0].decode("ascii")),
|
||||||
endpoint=msg[1].decode("ascii"),
|
endpoint=msg[1].decode("ascii"),
|
||||||
dst_port=int(msg[2].decode("ascii")),
|
dst_port=int(msg[2].decode("ascii")),
|
||||||
mooncake_session_id=msg[3].decode("ascii"),
|
mooncake_session_id=msg[3].decode("ascii"),
|
||||||
dst_kv_indices=np.frombuffer(msg[4], dtype=np.int64),
|
dst_kv_indices=dst_kv_indices,
|
||||||
dst_aux_index=int(msg[5].decode("ascii")),
|
dst_aux_index=dst_aux_index,
|
||||||
|
required_dst_info_num=int(msg[6].decode("ascii")),
|
||||||
|
is_dummy=is_dummy,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -108,6 +120,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
args: KVArgs,
|
args: KVArgs,
|
||||||
disaggregation_mode: DisaggregationMode,
|
disaggregation_mode: DisaggregationMode,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
|
is_mla_backend: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
self.kv_args = args
|
self.kv_args = args
|
||||||
self.engine = MooncakeTransferEngine(
|
self.engine = MooncakeTransferEngine(
|
||||||
@@ -115,6 +128,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
gpu_id=self.kv_args.gpu_id,
|
gpu_id=self.kv_args.gpu_id,
|
||||||
ib_device=self.kv_args.ib_device,
|
ib_device=self.kv_args.ib_device,
|
||||||
)
|
)
|
||||||
|
self.is_mla_backend = is_mla_backend
|
||||||
self.disaggregation_mode = disaggregation_mode
|
self.disaggregation_mode = disaggregation_mode
|
||||||
# for p/d multi node infer
|
# for p/d multi node infer
|
||||||
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
||||||
@@ -132,7 +146,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
self.register_buffer_to_engine()
|
self.register_buffer_to_engine()
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
self.transfer_queue = queue.Queue()
|
self.transfer_queue = queue.Queue()
|
||||||
self.transfer_infos: Dict[int, TransferInfo] = {}
|
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
|
||||||
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
|
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
|
||||||
self.start_prefill_thread()
|
self.start_prefill_thread()
|
||||||
self._register_to_bootstrap()
|
self._register_to_bootstrap()
|
||||||
@@ -145,6 +159,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
self.start_decode_thread()
|
self.start_decode_thread()
|
||||||
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
|
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
|
||||||
|
self.prefill_tp_size_table: Dict[str, int] = {}
|
||||||
self.prefill_dp_size_table: Dict[str, int] = {}
|
self.prefill_dp_size_table: Dict[str, int] = {}
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -218,7 +233,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
status = future.result()
|
status = future.result()
|
||||||
if status != 0:
|
if status != 0:
|
||||||
# Immediate shutdown on first error (existing tasks will finish)
|
# Immediate shutdown on first error (existing tasks will finish)
|
||||||
executor.shutdown(wait=False)
|
self.executor.shutdown(wait=False)
|
||||||
for f in futures:
|
for f in futures:
|
||||||
f.cancel()
|
f.cancel()
|
||||||
return status
|
return status
|
||||||
@@ -250,7 +265,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
|
self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
|
||||||
[
|
[
|
||||||
str(room).encode("ascii"),
|
str(room).encode("ascii"),
|
||||||
str(self.request_status[room]).encode("ascii"),
|
str(self.check_status(room)).encode("ascii"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -264,8 +279,8 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
while True:
|
while True:
|
||||||
waiting_req_bytes = self.server_socket.recv_multipart()
|
waiting_req_bytes = self.server_socket.recv_multipart()
|
||||||
room = waiting_req_bytes[0].decode("ascii")
|
room = waiting_req_bytes[0].decode("ascii")
|
||||||
if room == "None":
|
|
||||||
mooncake_session_id = waiting_req_bytes[3].decode("ascii")
|
mooncake_session_id = waiting_req_bytes[3].decode("ascii")
|
||||||
|
if room == "None":
|
||||||
self.decode_kv_args_table[mooncake_session_id] = (
|
self.decode_kv_args_table[mooncake_session_id] = (
|
||||||
KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
|
KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
|
||||||
)
|
)
|
||||||
@@ -273,19 +288,32 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
f"Register KVArgs from {mooncake_session_id} successfully"
|
f"Register KVArgs from {mooncake_session_id} successfully"
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
else:
|
||||||
|
required_dst_info_num = int(waiting_req_bytes[6].decode("ascii"))
|
||||||
room = int(room)
|
room = int(room)
|
||||||
self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
|
if room not in self.transfer_infos:
|
||||||
|
self.transfer_infos[room] = {}
|
||||||
|
|
||||||
|
self.transfer_infos[room][mooncake_session_id] = (
|
||||||
|
TransferInfo.from_zmq(waiting_req_bytes)
|
||||||
|
)
|
||||||
# NOTE: after bootstrapping we can mark the req as waiting for input
|
# NOTE: after bootstrapping we can mark the req as waiting for input
|
||||||
self.request_status[room] = KVPoll.WaitingForInput
|
if len(self.transfer_infos[room]) == required_dst_info_num:
|
||||||
|
self.update_status(room, KVPoll.WaitingForInput)
|
||||||
|
|
||||||
def transfer_thread():
|
def transfer_thread():
|
||||||
# TODO: Shall we use KVPoll.Transferring state?
|
# TODO: Shall we use KVPoll.Transferring state?
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01)
|
kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01)
|
||||||
req = self.transfer_infos[kv_chunk.room]
|
reqs_to_be_processed = self.transfer_infos[kv_chunk.room].values()
|
||||||
chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
|
polls = []
|
||||||
|
dst_ranks_infos = []
|
||||||
|
for req in reqs_to_be_processed:
|
||||||
|
if not req.is_dummy:
|
||||||
|
chunked_dst_kv_indice = req.dst_kv_indices[
|
||||||
|
kv_chunk.index_slice
|
||||||
|
]
|
||||||
assert len(chunked_dst_kv_indice) == len(
|
assert len(chunked_dst_kv_indice) == len(
|
||||||
kv_chunk.prefill_kv_indices
|
kv_chunk.prefill_kv_indices
|
||||||
), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
|
), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
|
||||||
@@ -293,11 +321,13 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
ret = self.send_kvcache(
|
ret = self.send_kvcache(
|
||||||
req.mooncake_session_id,
|
req.mooncake_session_id,
|
||||||
kv_chunk.prefill_kv_indices,
|
kv_chunk.prefill_kv_indices,
|
||||||
self.decode_kv_args_table[req.mooncake_session_id].dst_kv_ptrs,
|
self.decode_kv_args_table[
|
||||||
|
req.mooncake_session_id
|
||||||
|
].dst_kv_ptrs,
|
||||||
chunked_dst_kv_indice,
|
chunked_dst_kv_indice,
|
||||||
)
|
)
|
||||||
if ret != 0:
|
if ret != 0:
|
||||||
self.request_status[kv_chunk.room] = KVPoll.Failed
|
self.update_status(kv_chunk.room, KVPoll.Failed)
|
||||||
self.sync_status_to_decode_endpoint(
|
self.sync_status_to_decode_endpoint(
|
||||||
req.endpoint, req.dst_port, req.room
|
req.endpoint, req.dst_port, req.room
|
||||||
)
|
)
|
||||||
@@ -313,13 +343,29 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
].dst_aux_ptrs,
|
].dst_aux_ptrs,
|
||||||
req.dst_aux_index,
|
req.dst_aux_index,
|
||||||
)
|
)
|
||||||
self.request_status[req.room] = (
|
polls.append(True if ret == 0 else False)
|
||||||
KVPoll.Success if ret == 0 else KVPoll.Failed
|
dst_ranks_infos.append(
|
||||||
|
(req.endpoint, req.dst_port, req.room)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Only sync status when all the dst ranks have received the kvcache
|
||||||
|
if len(polls) == req.required_dst_info_num:
|
||||||
|
self.update_status(
|
||||||
|
req.room,
|
||||||
|
KVPoll.Success if all(polls) else KVPoll.Failed,
|
||||||
|
)
|
||||||
|
for endpoint, dst_port, room in dst_ranks_infos:
|
||||||
self.sync_status_to_decode_endpoint(
|
self.sync_status_to_decode_endpoint(
|
||||||
req.endpoint, req.dst_port, req.room
|
endpoint, dst_port, room
|
||||||
)
|
)
|
||||||
self.transfer_infos.pop(req.room)
|
else:
|
||||||
|
# Dummy request means the decode instance is not used, so its status can be marked as success directly
|
||||||
|
# Dummy request does not need to sync status to decode endpoint
|
||||||
|
if kv_chunk.is_last:
|
||||||
|
self.update_status(req.room, KVPoll.Success)
|
||||||
|
|
||||||
|
if self.check_status(kv_chunk.room) == KVPoll.Success:
|
||||||
|
self.transfer_infos.pop(kv_chunk.room)
|
||||||
|
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
continue
|
continue
|
||||||
@@ -336,7 +382,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
(bootstrap_room, status) = self.server_socket.recv_multipart()
|
(bootstrap_room, status) = self.server_socket.recv_multipart()
|
||||||
status = int(status.decode("ascii"))
|
status = int(status.decode("ascii"))
|
||||||
bootstrap_room = int(bootstrap_room.decode("ascii"))
|
bootstrap_room = int(bootstrap_room.decode("ascii"))
|
||||||
self.request_status[bootstrap_room] = status
|
self.update_status(bootstrap_room, status)
|
||||||
|
|
||||||
threading.Thread(target=decode_thread).start()
|
threading.Thread(target=decode_thread).start()
|
||||||
|
|
||||||
@@ -360,11 +406,9 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
prefill_aux_index=aux_index,
|
prefill_aux_index=aux_index,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.request_status[bootstrap_room] = KVPoll.WaitingForInput
|
self.update_status(bootstrap_room, KVPoll.WaitingForInput)
|
||||||
|
|
||||||
def check_status(self, bootstrap_room: int):
|
def check_status(self, bootstrap_room: int):
|
||||||
# TODO: do we really need the poll()?
|
|
||||||
|
|
||||||
return self.request_status[bootstrap_room]
|
return self.request_status[bootstrap_room]
|
||||||
|
|
||||||
def update_status(self, bootstrap_room: int, status: KVPoll):
|
def update_status(self, bootstrap_room: int, status: KVPoll):
|
||||||
@@ -469,54 +513,111 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
self.session_id = self.kv_mgr.get_session_id()
|
self.session_id = self.kv_mgr.get_session_id()
|
||||||
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
|
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
|
||||||
|
|
||||||
if not self.kv_mgr.enable_dp_attention:
|
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
||||||
# We assume dp_attention should be activated simultaneously for
|
self.prefill_tp_size, self.prefill_dp_size = (
|
||||||
# both prefill role and decode role. If the decode instance does
|
|
||||||
# not enable dp_attention, then dp_attention is not enabled on the
|
|
||||||
# prefill instance as well. Therefore, we should skip questioning
|
|
||||||
# the prefill dp size to reduce bootstrap overhead.
|
|
||||||
self.prefill_dp_size = 1
|
|
||||||
elif self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
|
||||||
self.prefill_dp_size, tp_size_per_dp_rank = (
|
|
||||||
self._get_prefill_dp_size_from_server()
|
self._get_prefill_dp_size_from_server()
|
||||||
)
|
)
|
||||||
# Currently, we don't allow prefill instance and decode instance to
|
if self.prefill_tp_size is None or self.prefill_dp_size is None:
|
||||||
# have different TP sizes per DP rank.
|
|
||||||
assert tp_size_per_dp_rank == self.kv_mgr.tp_size // self.kv_mgr.dp_size
|
|
||||||
if self.prefill_dp_size is None:
|
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Could not fetch prefill dp_size for bootstrap_addr: {self.bootstrap_addr}"
|
f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
|
||||||
|
self.prefill_tp_size
|
||||||
|
)
|
||||||
self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
|
self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
|
||||||
self.prefill_dp_size
|
self.prefill_dp_size
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
|
||||||
|
self.bootstrap_addr
|
||||||
|
]
|
||||||
self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
|
self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
|
||||||
self.bootstrap_addr
|
self.bootstrap_addr
|
||||||
]
|
]
|
||||||
|
|
||||||
# NOTE: key distinguished by bootstrap_addr and engine_rank
|
# Currently, we don't allow prefill instance and decode instance to
|
||||||
|
# have different TP sizes per DP rank, except for models using MLA.
|
||||||
|
local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
|
||||||
|
prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
|
||||||
|
if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
|
||||||
|
self.target_tp_rank = (
|
||||||
|
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
|
||||||
|
)
|
||||||
|
self.required_dst_info_num = 1
|
||||||
|
self.target_tp_ranks = [self.target_tp_rank]
|
||||||
|
elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
|
||||||
|
assert (
|
||||||
|
self.kv_mgr.is_mla_backend
|
||||||
|
), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
|
||||||
|
self.target_tp_rank = (
|
||||||
|
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
|
||||||
|
) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
|
||||||
|
self.required_dst_info_num = (
|
||||||
|
local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
|
||||||
|
)
|
||||||
|
self.target_tp_ranks = [self.target_tp_rank]
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
self.kv_mgr.is_mla_backend
|
||||||
|
), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
|
||||||
|
|
||||||
|
# For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
|
||||||
|
self.target_tp_ranks = [
|
||||||
|
rank
|
||||||
|
for rank in range(
|
||||||
|
(self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
|
||||||
|
* (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
|
||||||
|
(self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
|
||||||
|
* (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
|
||||||
|
# multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
|
||||||
|
# or the KVPoll will never be set correctly
|
||||||
|
self.target_tp_rank = self.target_tp_ranks[0]
|
||||||
|
self.required_dst_info_num = 1
|
||||||
|
|
||||||
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
||||||
bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
|
|
||||||
|
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
||||||
|
bootstrap_key = (
|
||||||
|
f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
|
||||||
|
)
|
||||||
|
|
||||||
if bootstrap_key not in self.kv_mgr.connection_pool:
|
if bootstrap_key not in self.kv_mgr.connection_pool:
|
||||||
self.bootstrap_info = self._get_bootstrap_info_from_server(
|
bootstrap_infos = []
|
||||||
self.kv_mgr.kv_args.engine_rank,
|
for target_tp_rank in self.target_tp_ranks:
|
||||||
|
bootstrap_info = self._get_bootstrap_info_from_server(
|
||||||
|
target_tp_rank,
|
||||||
self.target_dp_group,
|
self.target_dp_group,
|
||||||
)
|
)
|
||||||
if self.bootstrap_info is None:
|
if bootstrap_info is not None:
|
||||||
|
# NOTE: only support MLA for now: select one prefill rank as real rank
|
||||||
|
bootstrap_info["is_dummy"] = not bool(
|
||||||
|
target_tp_rank == self.target_tp_rank
|
||||||
|
or self.target_tp_rank is None
|
||||||
|
)
|
||||||
|
bootstrap_infos.append(bootstrap_info)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
|
||||||
|
)
|
||||||
|
self.bootstrap_infos = bootstrap_infos
|
||||||
|
|
||||||
|
if len(self.bootstrap_infos) == 0:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
|
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
|
||||||
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
|
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
|
||||||
self._register_kv_args()
|
self._register_kv_args()
|
||||||
else:
|
else:
|
||||||
self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
|
self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
|
||||||
|
|
||||||
assert self.bootstrap_info is not None
|
assert len(self.bootstrap_infos) > 0
|
||||||
self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
|
self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
|
||||||
|
|
||||||
def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
|
def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
|
||||||
@@ -543,8 +644,8 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
response = requests.get(url)
|
response = requests.get(url)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
prefill_parallel_info = response.json()
|
prefill_parallel_info = response.json()
|
||||||
return int(prefill_parallel_info["prefill_dp_size"]), int(
|
return int(prefill_parallel_info["prefill_tp_size"]), int(
|
||||||
prefill_parallel_info["tp_size_per_dp_rank"]
|
prefill_parallel_info["prefill_dp_size"]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(
|
||||||
@@ -556,16 +657,17 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def _register_kv_args(self):
|
def _register_kv_args(self):
|
||||||
|
for bootstrap_info in self.bootstrap_infos:
|
||||||
self.prefill_server_url = (
|
self.prefill_server_url = (
|
||||||
f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
|
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
||||||
)
|
)
|
||||||
|
|
||||||
packed_kv_data_ptrs = b"".join(
|
packed_kv_data_ptrs = b"".join(
|
||||||
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
||||||
)
|
)
|
||||||
packed_aux_data_ptrs = b"".join(
|
packed_aux_data_ptrs = b"".join(
|
||||||
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
|
||||||
)
|
)
|
||||||
|
|
||||||
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
||||||
with lock:
|
with lock:
|
||||||
sock.send_multipart(
|
sock.send_multipart(
|
||||||
@@ -590,12 +692,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
||||||
|
|
||||||
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
|
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
|
||||||
|
for bootstrap_info in self.bootstrap_infos:
|
||||||
self.prefill_server_url = (
|
self.prefill_server_url = (
|
||||||
f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
|
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
||||||
)
|
)
|
||||||
|
is_dummy = bootstrap_info["is_dummy"]
|
||||||
|
|
||||||
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
||||||
with lock:
|
with lock:
|
||||||
@@ -605,8 +709,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
get_local_ip_by_remote().encode("ascii"),
|
get_local_ip_by_remote().encode("ascii"),
|
||||||
str(self.kv_mgr.rank_port).encode("ascii"),
|
str(self.kv_mgr.rank_port).encode("ascii"),
|
||||||
self.session_id.encode("ascii"),
|
self.session_id.encode("ascii"),
|
||||||
kv_indices.tobytes(),
|
kv_indices.tobytes() if not is_dummy else b"",
|
||||||
str(aux_index).encode("ascii"),
|
str(aux_index).encode("ascii") if not is_dummy else b"",
|
||||||
|
str(self.required_dst_info_num).encode("ascii"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -624,6 +729,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
|||||||
self.store = dict()
|
self.store = dict()
|
||||||
self.lock = asyncio.Lock()
|
self.lock = asyncio.Lock()
|
||||||
self._setup_routes()
|
self._setup_routes()
|
||||||
|
self.tp_size = None
|
||||||
self.dp_size = None
|
self.dp_size = None
|
||||||
self.tp_size_per_dp_rank = None
|
self.tp_size_per_dp_rank = None
|
||||||
self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
|
self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
|
||||||
@@ -658,6 +764,9 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
|||||||
rank_port = int(data["rank_port"])
|
rank_port = int(data["rank_port"])
|
||||||
engine_rank = int(data["engine_rank"])
|
engine_rank = int(data["engine_rank"])
|
||||||
|
|
||||||
|
if self.tp_size is None:
|
||||||
|
self.tp_size = tp_size
|
||||||
|
|
||||||
if self.dp_size is None:
|
if self.dp_size is None:
|
||||||
self.dp_size = dp_size
|
self.dp_size = dp_size
|
||||||
|
|
||||||
@@ -693,17 +802,15 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
|||||||
# Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
|
# Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
|
||||||
if int(engine_rank) == -1 and int(target_dp_group) == -1:
|
if int(engine_rank) == -1 and int(target_dp_group) == -1:
|
||||||
prefill_parallel_info = {
|
prefill_parallel_info = {
|
||||||
|
"prefill_tp_size": self.tp_size,
|
||||||
"prefill_dp_size": self.dp_size,
|
"prefill_dp_size": self.dp_size,
|
||||||
"tp_size_per_dp_rank": self.tp_size_per_dp_rank,
|
|
||||||
}
|
}
|
||||||
return web.json_response(prefill_parallel_info, status=200)
|
return web.json_response(prefill_parallel_info, status=200)
|
||||||
|
|
||||||
# Find corresponding prefill info
|
# Find corresponding prefill info
|
||||||
tp_rank_in_dp_group = int(engine_rank) % self.tp_size_per_dp_rank
|
|
||||||
|
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
bootstrap_info = self.prefill_port_table[int(target_dp_group)][
|
bootstrap_info = self.prefill_port_table[int(target_dp_group)][
|
||||||
tp_rank_in_dp_group
|
int(engine_rank)
|
||||||
]
|
]
|
||||||
|
|
||||||
if bootstrap_info is not None:
|
if bootstrap_info is not None:
|
||||||
|
|||||||
@@ -132,6 +132,7 @@ class NixlKVManager(BaseKVManager):
|
|||||||
args: KVArgs,
|
args: KVArgs,
|
||||||
disaggregation_mode: DisaggregationMode,
|
disaggregation_mode: DisaggregationMode,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
|
is_mla_backend: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
from nixl._api import nixl_agent
|
from nixl._api import nixl_agent
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
ReqToMetadataIdxAllocator,
|
ReqToMetadataIdxAllocator,
|
||||||
TransferBackend,
|
TransferBackend,
|
||||||
get_kv_class,
|
get_kv_class,
|
||||||
|
is_mla_backend,
|
||||||
kv_to_page_indices,
|
kv_to_page_indices,
|
||||||
kv_to_page_num,
|
kv_to_page_num,
|
||||||
poll_and_all_reduce,
|
poll_and_all_reduce,
|
||||||
@@ -69,6 +70,7 @@ class PrefillBootstrapQueue:
|
|||||||
scheduler: Scheduler,
|
scheduler: Scheduler,
|
||||||
):
|
):
|
||||||
self.token_to_kv_pool = token_to_kv_pool
|
self.token_to_kv_pool = token_to_kv_pool
|
||||||
|
self.is_mla_backend = is_mla_backend(token_to_kv_pool)
|
||||||
self.aux_dtype = aux_dtype
|
self.aux_dtype = aux_dtype
|
||||||
|
|
||||||
self.metadata_buffers = metadata_buffers
|
self.metadata_buffers = metadata_buffers
|
||||||
@@ -112,7 +114,10 @@ class PrefillBootstrapQueue:
|
|||||||
kv_args.gpu_id = self.scheduler.gpu_id
|
kv_args.gpu_id = self.scheduler.gpu_id
|
||||||
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
||||||
kv_manager = kv_manager_class(
|
kv_manager = kv_manager_class(
|
||||||
kv_args, DisaggregationMode.PREFILL, self.scheduler.server_args
|
kv_args,
|
||||||
|
DisaggregationMode.PREFILL,
|
||||||
|
self.scheduler.server_args,
|
||||||
|
self.is_mla_backend,
|
||||||
)
|
)
|
||||||
return kv_manager
|
return kv_manager
|
||||||
|
|
||||||
|
|||||||
@@ -162,3 +162,9 @@ def register_disaggregation_server(
|
|||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"Failed to register disaggregation server: {res.status_code} {res.text}"
|
f"Failed to register disaggregation server: {res.status_code} {res.text}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_mla_backend(target_kv_pool) -> bool:
|
||||||
|
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
|
||||||
|
|
||||||
|
return isinstance(target_kv_pool, MLATokenToKVPool)
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ suites = {
|
|||||||
# TestFile("test_moe_deepep_eval_accuracy_large.py", 250),
|
# TestFile("test_moe_deepep_eval_accuracy_large.py", 250),
|
||||||
TestFile("test_disaggregation.py", 210),
|
TestFile("test_disaggregation.py", 210),
|
||||||
TestFile("test_local_attn.py", 250),
|
TestFile("test_local_attn.py", 250),
|
||||||
|
TestFile("test_disaggregation_different_tp.py", 210),
|
||||||
TestFile("test_full_deepseek_v3.py", 250),
|
TestFile("test_full_deepseek_v3.py", 250),
|
||||||
TestFile("test_pp_single_node.py", 150),
|
TestFile("test_pp_single_node.py", 150),
|
||||||
],
|
],
|
||||||
|
|||||||
151
test/srt/test_disaggregation_different_tp.py
Normal file
151
test/srt/test_disaggregation_different_tp.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
popen_launch_pd_server,
|
||||||
|
run_with_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDisaggregationMooncakeDifferentTP(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
# Temporarily disable JIT DeepGEMM
|
||||||
|
cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM")
|
||||||
|
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
|
||||||
|
|
||||||
|
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
|
||||||
|
cls.base_host = "127.0.0.1"
|
||||||
|
cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1])
|
||||||
|
cls.lb_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}"
|
||||||
|
cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}"
|
||||||
|
|
||||||
|
run_with_timeout(cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
|
||||||
|
run_with_timeout(cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
|
||||||
|
|
||||||
|
cls.wait_server_ready(cls.prefill_url + "/health")
|
||||||
|
cls.wait_server_ready(cls.decode_url + "/health")
|
||||||
|
|
||||||
|
lb_command = [
|
||||||
|
"python3",
|
||||||
|
"-m",
|
||||||
|
"sglang.srt.disaggregation.mini_lb",
|
||||||
|
"--prefill",
|
||||||
|
cls.prefill_url,
|
||||||
|
"--decode",
|
||||||
|
cls.decode_url,
|
||||||
|
"--host",
|
||||||
|
cls.base_host,
|
||||||
|
"--port",
|
||||||
|
str(cls.base_port),
|
||||||
|
]
|
||||||
|
|
||||||
|
print("Starting load balancer:", " ".join(lb_command))
|
||||||
|
cls.process_lb = subprocess.Popen(
|
||||||
|
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||||
|
)
|
||||||
|
cls.wait_server_ready(cls.lb_url + "/health")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def start_prefill(cls):
|
||||||
|
prefill_args = [
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--disaggregation-mode",
|
||||||
|
"prefill",
|
||||||
|
"--host",
|
||||||
|
cls.base_host,
|
||||||
|
"--port",
|
||||||
|
str(cls.base_port + 100),
|
||||||
|
"--tp",
|
||||||
|
"4",
|
||||||
|
]
|
||||||
|
cls.process_prefill = popen_launch_pd_server(
|
||||||
|
cls.model,
|
||||||
|
cls.prefill_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=prefill_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def start_decode(cls):
|
||||||
|
decode_args = [
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--disaggregation-mode",
|
||||||
|
"decode",
|
||||||
|
"--host",
|
||||||
|
cls.base_host,
|
||||||
|
"--port",
|
||||||
|
str(cls.base_port + 200),
|
||||||
|
"--tp",
|
||||||
|
"2",
|
||||||
|
"--base-gpu-id",
|
||||||
|
"4",
|
||||||
|
]
|
||||||
|
cls.process_decode = popen_launch_pd_server(
|
||||||
|
cls.model,
|
||||||
|
cls.decode_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=decode_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def wait_server_ready(cls, url, timeout=60):
|
||||||
|
start_time = time.time()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
response = requests.get(url)
|
||||||
|
if response.status_code == 200:
|
||||||
|
print(f"Server {url} is ready")
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if time.time() - start_time > timeout:
|
||||||
|
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
# Restore JIT DeepGEMM environment variable
|
||||||
|
if cls.original_jit_deepgemm is not None:
|
||||||
|
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = cls.original_jit_deepgemm
|
||||||
|
else:
|
||||||
|
os.environ.pop("SGL_ENABLE_JIT_DEEPGEMM", None)
|
||||||
|
|
||||||
|
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
|
||||||
|
if process:
|
||||||
|
try:
|
||||||
|
kill_process_tree(process.pid)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error killing process {process.pid}: {e}")
|
||||||
|
|
||||||
|
def test_gsm8k(self):
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=5,
|
||||||
|
data_path=None,
|
||||||
|
num_questions=200,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host="http://127.0.0.1",
|
||||||
|
port=int(self.lb_url.split(":")[-1]),
|
||||||
|
)
|
||||||
|
metrics = run_eval_few_shot_gsm8k(args)
|
||||||
|
print(f"Evaluation metrics: {metrics}")
|
||||||
|
|
||||||
|
self.assertGreater(metrics["accuracy"], 0.60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user