[PD] NIXL: Register kv args in advance and cleanup finished requests (#6717)
This commit is contained in:
@@ -31,23 +31,19 @@ from sglang.srt.utils import get_local_ip_by_remote
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
|
|
||||||
|
|
||||||
GUARD = "NixlMsgGuard".encode("ascii")
|
GUARD = "NixlMsgGuard".encode("ascii")
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class TransferInfo:
|
class TransferInfo:
|
||||||
|
"""Contains indices for a transfer, sent by KVReceiver. Received by prefill bootstrap thread."""
|
||||||
|
|
||||||
room: int
|
room: int
|
||||||
endpoint: str
|
endpoint: str
|
||||||
dst_port: int
|
dst_port: int
|
||||||
agent_metadata: bytes
|
|
||||||
agent_name: str
|
agent_name: str
|
||||||
dst_kv_ptrs: list[int]
|
|
||||||
dst_kv_indices: npt.NDArray[np.int32]
|
dst_kv_indices: npt.NDArray[np.int32]
|
||||||
dst_aux_ptrs: list[int]
|
|
||||||
dst_aux_index: int
|
dst_aux_index: int
|
||||||
dst_gpu_id: int
|
|
||||||
required_dst_info_num: int
|
required_dst_info_num: int
|
||||||
|
|
||||||
def is_dummy(self):
|
def is_dummy(self):
|
||||||
@@ -59,14 +55,37 @@ class TransferInfo:
|
|||||||
room=int(msg[0].decode("ascii")),
|
room=int(msg[0].decode("ascii")),
|
||||||
endpoint=msg[1].decode("ascii"),
|
endpoint=msg[1].decode("ascii"),
|
||||||
dst_port=int(msg[2].decode("ascii")),
|
dst_port=int(msg[2].decode("ascii")),
|
||||||
agent_metadata=msg[3],
|
agent_name=msg[3].decode("ascii"),
|
||||||
agent_name=msg[4].decode("ascii"),
|
dst_kv_indices=np.frombuffer(msg[4], dtype=np.int32),
|
||||||
|
dst_aux_index=int(msg[5].decode("ascii")),
|
||||||
|
required_dst_info_num=int(msg[6].decode("ascii")),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class KVArgsRegisterInfo:
|
||||||
|
"""Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread."""
|
||||||
|
|
||||||
|
room: str
|
||||||
|
endpoint: str
|
||||||
|
dst_port: int
|
||||||
|
agent_name: str
|
||||||
|
agent_metadata: bytes
|
||||||
|
dst_kv_ptrs: list[int]
|
||||||
|
dst_aux_ptrs: list[int]
|
||||||
|
gpu_id: int
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_zmq(cls, msg: List[bytes]):
|
||||||
|
return cls(
|
||||||
|
room=str(msg[0].decode("ascii")),
|
||||||
|
endpoint=msg[1].decode("ascii"),
|
||||||
|
dst_port=int(msg[2].decode("ascii")),
|
||||||
|
agent_name=msg[3].decode("ascii"),
|
||||||
|
agent_metadata=msg[4],
|
||||||
dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
|
dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
|
||||||
dst_kv_indices=np.frombuffer(msg[6], dtype=np.int32),
|
dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
|
||||||
dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
|
gpu_id=int(msg[7].decode("ascii")),
|
||||||
dst_aux_index=int(msg[8].decode("ascii")),
|
|
||||||
dst_gpu_id=int(msg[9].decode("ascii")),
|
|
||||||
required_dst_info_num=int(msg[10].decode("ascii")),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -109,9 +128,9 @@ class NixlKVManager(CommonKVManager):
|
|||||||
self.register_buffer_to_engine()
|
self.register_buffer_to_engine()
|
||||||
|
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
self.request_status = {}
|
self.request_status: Dict[int, KVPoll] = {}
|
||||||
self.transfer_infos: Dict[int, TransferInfo] = {}
|
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
|
||||||
self.peer_names: Dict[str, str] = {}
|
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
|
||||||
self._start_bootstrap_thread()
|
self._start_bootstrap_thread()
|
||||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
|
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
|
||||||
@@ -154,10 +173,13 @@ class NixlKVManager(CommonKVManager):
|
|||||||
if not self.aux_descs:
|
if not self.aux_descs:
|
||||||
raise Exception("NIXL memory registration failed for aux tensors")
|
raise Exception("NIXL memory registration failed for aux tensors")
|
||||||
|
|
||||||
def _add_remote(self, agent_name: str, agent_metadata: bytes):
|
def _add_remote_peer(self, decode_kv_args: KVArgsRegisterInfo):
|
||||||
if agent_name not in self.peer_names:
|
agent_name = decode_kv_args.agent_name
|
||||||
self.peer_names[agent_name] = self.agent.add_remote_agent(agent_metadata)
|
if agent_name in self.decode_kv_args_table:
|
||||||
return self.peer_names[agent_name]
|
logger.info(f"Peer {agent_name} was already registered, ignoring.")
|
||||||
|
return
|
||||||
|
self.decode_kv_args_table[agent_name] = decode_kv_args
|
||||||
|
self.agent.add_remote_agent(decode_kv_args.agent_metadata)
|
||||||
|
|
||||||
def send_kvcache(
|
def send_kvcache(
|
||||||
self,
|
self,
|
||||||
@@ -262,17 +284,17 @@ class NixlKVManager(CommonKVManager):
|
|||||||
if req.is_dummy():
|
if req.is_dummy():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
peer_name = self._add_remote(req.agent_name, req.agent_metadata)
|
|
||||||
chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
|
chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
|
||||||
assert len(chunked_dst_kv_indice) == len(kv_indices)
|
assert len(chunked_dst_kv_indice) == len(kv_indices)
|
||||||
|
assert req.agent_name in self.decode_kv_args_table
|
||||||
|
|
||||||
notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
|
notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
|
||||||
kv_xfer_handle = self.send_kvcache(
|
kv_xfer_handle = self.send_kvcache(
|
||||||
peer_name,
|
req.agent_name,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
req.dst_kv_ptrs,
|
self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
|
||||||
chunked_dst_kv_indice,
|
chunked_dst_kv_indice,
|
||||||
req.dst_gpu_id,
|
self.decode_kv_args_table[req.agent_name].gpu_id,
|
||||||
notif,
|
notif,
|
||||||
)
|
)
|
||||||
handles.append(kv_xfer_handle)
|
handles.append(kv_xfer_handle)
|
||||||
@@ -280,13 +302,15 @@ class NixlKVManager(CommonKVManager):
|
|||||||
if is_last:
|
if is_last:
|
||||||
assert aux_index is not None
|
assert aux_index is not None
|
||||||
aux_xfer_handle = self.send_aux(
|
aux_xfer_handle = self.send_aux(
|
||||||
peer_name,
|
req.agent_name,
|
||||||
aux_index,
|
aux_index,
|
||||||
req.dst_aux_ptrs,
|
self.decode_kv_args_table[req.agent_name].dst_aux_ptrs,
|
||||||
req.dst_aux_index,
|
req.dst_aux_index,
|
||||||
str(req.room) + "_aux",
|
str(req.room) + "_aux",
|
||||||
)
|
)
|
||||||
handles.append(aux_xfer_handle)
|
handles.append(aux_xfer_handle)
|
||||||
|
if is_last:
|
||||||
|
del self.transfer_infos[bootstrap_room]
|
||||||
return handles
|
return handles
|
||||||
|
|
||||||
def update_transfer_status(self):
|
def update_transfer_status(self):
|
||||||
@@ -328,16 +352,23 @@ class NixlKVManager(CommonKVManager):
|
|||||||
), f"First message should be {GUARD}. Foreign traffic?"
|
), f"First message should be {GUARD}. Foreign traffic?"
|
||||||
waiting_req_bytes = waiting_req_bytes[1:]
|
waiting_req_bytes = waiting_req_bytes[1:]
|
||||||
room = waiting_req_bytes[0].decode("ascii")
|
room = waiting_req_bytes[0].decode("ascii")
|
||||||
|
agent_name = waiting_req_bytes[3].decode("ascii")
|
||||||
required_dst_info_num = int(waiting_req_bytes[10].decode("ascii"))
|
if room == "None":
|
||||||
|
# Register new peer and save KV base pointers.
|
||||||
|
self._add_remote_peer(
|
||||||
|
KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
|
||||||
|
)
|
||||||
|
logger.debug(f"Register KVArgs from {agent_name} successfully")
|
||||||
|
continue
|
||||||
room = int(room)
|
room = int(room)
|
||||||
agent_name = waiting_req_bytes[4].decode("ascii")
|
|
||||||
if room not in self.transfer_infos:
|
if room not in self.transfer_infos:
|
||||||
self.transfer_infos[room] = {}
|
self.transfer_infos[room] = {}
|
||||||
self.transfer_infos[room][agent_name] = TransferInfo.from_zmq(
|
self.transfer_infos[room][agent_name] = TransferInfo.from_zmq(
|
||||||
waiting_req_bytes
|
waiting_req_bytes
|
||||||
)
|
)
|
||||||
|
required_dst_info_num = self.transfer_infos[room][
|
||||||
|
agent_name
|
||||||
|
].required_dst_info_num
|
||||||
logger.debug(f"got info {room=} {agent_name=} {required_dst_info_num=}")
|
logger.debug(f"got info {room=} {agent_name=} {required_dst_info_num=}")
|
||||||
if len(self.transfer_infos[room]) == required_dst_info_num:
|
if len(self.transfer_infos[room]) == required_dst_info_num:
|
||||||
logger.debug(f"{room=} is bootstrapped")
|
logger.debug(f"{room=} is bootstrapped")
|
||||||
@@ -391,6 +422,7 @@ class NixlKVSender(BaseKVSender):
|
|||||||
self.chunk_id += 1
|
self.chunk_id += 1
|
||||||
if is_last:
|
if is_last:
|
||||||
self.has_sent = True
|
self.has_sent = True
|
||||||
|
del self.kv_mgr.request_status[self.bootstrap_room]
|
||||||
|
|
||||||
def poll(self) -> KVPoll:
|
def poll(self) -> KVPoll:
|
||||||
if not self.has_sent:
|
if not self.has_sent:
|
||||||
@@ -415,6 +447,7 @@ class NixlKVReceiver(CommonKVReceiver):
|
|||||||
data_parallel_rank: Optional[int] = None,
|
data_parallel_rank: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.started_transfer = False
|
self.started_transfer = False
|
||||||
|
self.conclude_state = None
|
||||||
super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
|
super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
|
||||||
|
|
||||||
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
||||||
@@ -426,17 +459,8 @@ class NixlKVReceiver(CommonKVReceiver):
|
|||||||
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
||||||
)
|
)
|
||||||
is_dummy = bootstrap_info["is_dummy"]
|
is_dummy = bootstrap_info["is_dummy"]
|
||||||
|
|
||||||
# TODO: send_kv_args earlier
|
|
||||||
packed_kv_data_ptrs = b"".join(
|
|
||||||
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
|
||||||
)
|
|
||||||
packed_aux_data_ptrs = b"".join(
|
|
||||||
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}"
|
f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room} {is_dummy=}"
|
||||||
)
|
)
|
||||||
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
||||||
with lock:
|
with lock:
|
||||||
@@ -446,13 +470,9 @@ class NixlKVReceiver(CommonKVReceiver):
|
|||||||
str(self.bootstrap_room).encode("ascii"),
|
str(self.bootstrap_room).encode("ascii"),
|
||||||
get_local_ip_by_remote().encode("ascii"),
|
get_local_ip_by_remote().encode("ascii"),
|
||||||
str(self.kv_mgr.rank_port).encode("ascii"),
|
str(self.kv_mgr.rank_port).encode("ascii"),
|
||||||
self.kv_mgr.agent.get_agent_metadata(),
|
|
||||||
self.kv_mgr.agent.name.encode("ascii"),
|
self.kv_mgr.agent.name.encode("ascii"),
|
||||||
packed_kv_data_ptrs,
|
|
||||||
kv_indices.tobytes() if not is_dummy else b"",
|
kv_indices.tobytes() if not is_dummy else b"",
|
||||||
packed_aux_data_ptrs,
|
|
||||||
str(aux_index).encode("ascii"),
|
str(aux_index).encode("ascii"),
|
||||||
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
|
|
||||||
str(self.required_dst_info_num).encode("ascii"),
|
str(self.required_dst_info_num).encode("ascii"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -460,17 +480,45 @@ class NixlKVReceiver(CommonKVReceiver):
|
|||||||
self.started_transfer = True
|
self.started_transfer = True
|
||||||
|
|
||||||
def poll(self) -> KVPoll:
|
def poll(self) -> KVPoll:
|
||||||
|
if self.conclude_state is not None:
|
||||||
|
return self.conclude_state
|
||||||
if not self.started_transfer:
|
if not self.started_transfer:
|
||||||
return KVPoll.WaitingForInput # type: ignore
|
return KVPoll.WaitingForInput # type: ignore
|
||||||
|
|
||||||
self.kv_mgr.update_transfer_status()
|
self.kv_mgr.update_transfer_status()
|
||||||
|
|
||||||
if self.kv_mgr.check_transfer_done(self.bootstrap_room): # type: ignore
|
if self.kv_mgr.check_transfer_done(self.bootstrap_room): # type: ignore
|
||||||
|
self.conclude_state = KVPoll.Success
|
||||||
|
del self.kv_mgr.transfer_statuses[self.bootstrap_room]
|
||||||
return KVPoll.Success # type: ignore
|
return KVPoll.Success # type: ignore
|
||||||
return KVPoll.WaitingForInput # type: ignore
|
return KVPoll.WaitingForInput # type: ignore
|
||||||
|
|
||||||
def _register_kv_args(self):
|
def _register_kv_args(self):
|
||||||
pass
|
for bootstrap_info in self.bootstrap_infos:
|
||||||
|
self.prefill_server_url = (
|
||||||
|
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
||||||
|
)
|
||||||
|
packed_kv_data_ptrs = b"".join(
|
||||||
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
||||||
|
)
|
||||||
|
packed_aux_data_ptrs = b"".join(
|
||||||
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
|
||||||
|
)
|
||||||
|
|
||||||
|
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
||||||
|
with lock:
|
||||||
|
sock.send_multipart(
|
||||||
|
[
|
||||||
|
GUARD,
|
||||||
|
"None".encode("ascii"),
|
||||||
|
get_local_ip_by_remote().encode("ascii"),
|
||||||
|
str(self.kv_mgr.rank_port).encode("ascii"),
|
||||||
|
self.kv_mgr.agent.name.encode("ascii"),
|
||||||
|
self.kv_mgr.agent.get_agent_metadata(),
|
||||||
|
packed_kv_data_ptrs,
|
||||||
|
packed_aux_data_ptrs,
|
||||||
|
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def failure_exception(self):
|
def failure_exception(self):
|
||||||
raise Exception("Fake KVReceiver Exception")
|
raise Exception("Fake KVReceiver Exception")
|
||||||
|
|||||||
Reference in New Issue
Block a user