[PD] bug fix: Update status if nixl receiver sends a dummy req. (#6720)
This commit is contained in:
@@ -53,26 +53,10 @@ class TransferInfo:
|
|||||||
required_dst_info_num: int
|
required_dst_info_num: int
|
||||||
|
|
||||||
def is_dummy(self):
|
def is_dummy(self):
|
||||||
return self.endpoint == ""
|
return self.dst_kv_indices.size == 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_zmq(cls, msg: List[bytes]):
|
def from_zmq(cls, msg: List[bytes]):
|
||||||
if len(msg) == 1:
|
|
||||||
# dummy msg
|
|
||||||
return cls(
|
|
||||||
room=int(msg[0].decode("ascii")),
|
|
||||||
endpoint="",
|
|
||||||
dst_port=0,
|
|
||||||
agent_metadata=b"",
|
|
||||||
agent_name="",
|
|
||||||
dst_kv_ptrs=[],
|
|
||||||
dst_kv_indices=np.array([], dtype=np.int64),
|
|
||||||
dst_aux_ptrs=[],
|
|
||||||
dst_aux_index=0,
|
|
||||||
dst_gpu_id=0,
|
|
||||||
required_dst_info_num=0,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return cls(
|
return cls(
|
||||||
room=int(msg[0].decode("ascii")),
|
room=int(msg[0].decode("ascii")),
|
||||||
endpoint=msg[1].decode("ascii"),
|
endpoint=msg[1].decode("ascii"),
|
||||||
@@ -278,7 +262,7 @@ class NixlKVManager(CommonKVManager):
|
|||||||
for req in reqs_to_be_processed:
|
for req in reqs_to_be_processed:
|
||||||
assert bootstrap_room == req.room
|
assert bootstrap_room == req.room
|
||||||
if req.is_dummy():
|
if req.is_dummy():
|
||||||
return []
|
continue
|
||||||
|
|
||||||
peer_name = self._add_remote(req.agent_name, req.agent_metadata)
|
peer_name = self._add_remote(req.agent_name, req.agent_metadata)
|
||||||
chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
|
chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
|
||||||
@@ -346,8 +330,7 @@ class NixlKVManager(CommonKVManager):
|
|||||||
), f"First message should be {GUARD}. Foreign traffic?"
|
), f"First message should be {GUARD}. Foreign traffic?"
|
||||||
waiting_req_bytes = waiting_req_bytes[1:]
|
waiting_req_bytes = waiting_req_bytes[1:]
|
||||||
room = waiting_req_bytes[0].decode("ascii")
|
room = waiting_req_bytes[0].decode("ascii")
|
||||||
if room == "None":
|
|
||||||
continue
|
|
||||||
required_dst_info_num = int(waiting_req_bytes[10].decode("ascii"))
|
required_dst_info_num = int(waiting_req_bytes[10].decode("ascii"))
|
||||||
room = int(room)
|
room = int(room)
|
||||||
agent_name = waiting_req_bytes[4].decode("ascii")
|
agent_name = waiting_req_bytes[4].decode("ascii")
|
||||||
@@ -438,19 +421,6 @@ class NixlKVReceiver(CommonKVReceiver):
|
|||||||
)
|
)
|
||||||
is_dummy = bootstrap_info["is_dummy"]
|
is_dummy = bootstrap_info["is_dummy"]
|
||||||
|
|
||||||
# TODO: just send "" for indices for dummy
|
|
||||||
if is_dummy:
|
|
||||||
# TODO: need to set success??
|
|
||||||
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
|
||||||
with lock:
|
|
||||||
sock.send_multipart(
|
|
||||||
[
|
|
||||||
GUARD,
|
|
||||||
str(self.bootstrap_room).encode("ascii"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# TODO: send_kv_args earlier
|
# TODO: send_kv_args earlier
|
||||||
packed_kv_data_ptrs = b"".join(
|
packed_kv_data_ptrs = b"".join(
|
||||||
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
||||||
@@ -473,7 +443,7 @@ class NixlKVReceiver(CommonKVReceiver):
|
|||||||
self.kv_mgr.agent.get_agent_metadata(),
|
self.kv_mgr.agent.get_agent_metadata(),
|
||||||
self.kv_mgr.agent.name.encode("ascii"),
|
self.kv_mgr.agent.name.encode("ascii"),
|
||||||
packed_kv_data_ptrs,
|
packed_kv_data_ptrs,
|
||||||
kv_indices.tobytes(),
|
kv_indices.tobytes() if not is_dummy else b"",
|
||||||
packed_aux_data_ptrs,
|
packed_aux_data_ptrs,
|
||||||
str(aux_index).encode("ascii"),
|
str(aux_index).encode("ascii"),
|
||||||
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
|
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
|
||||||
|
|||||||
Reference in New Issue
Block a user