[PD] bug fix: Update status if nixl receiver sends a dummy req. (#6720)
This commit is contained in:
@@ -53,39 +53,23 @@ class TransferInfo:
|
||||
required_dst_info_num: int
|
||||
|
||||
def is_dummy(self):
|
||||
return self.endpoint == ""
|
||||
return self.dst_kv_indices.size == 0
|
||||
|
||||
@classmethod
|
||||
def from_zmq(cls, msg: List[bytes]):
|
||||
if len(msg) == 1:
|
||||
# dummy msg
|
||||
return cls(
|
||||
room=int(msg[0].decode("ascii")),
|
||||
endpoint="",
|
||||
dst_port=0,
|
||||
agent_metadata=b"",
|
||||
agent_name="",
|
||||
dst_kv_ptrs=[],
|
||||
dst_kv_indices=np.array([], dtype=np.int64),
|
||||
dst_aux_ptrs=[],
|
||||
dst_aux_index=0,
|
||||
dst_gpu_id=0,
|
||||
required_dst_info_num=0,
|
||||
)
|
||||
else:
|
||||
return cls(
|
||||
room=int(msg[0].decode("ascii")),
|
||||
endpoint=msg[1].decode("ascii"),
|
||||
dst_port=int(msg[2].decode("ascii")),
|
||||
agent_metadata=msg[3],
|
||||
agent_name=msg[4].decode("ascii"),
|
||||
dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
|
||||
dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64),
|
||||
dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
|
||||
dst_aux_index=int(msg[8].decode("ascii")),
|
||||
dst_gpu_id=int(msg[9].decode("ascii")),
|
||||
required_dst_info_num=int(msg[10].decode("ascii")),
|
||||
)
|
||||
return cls(
|
||||
room=int(msg[0].decode("ascii")),
|
||||
endpoint=msg[1].decode("ascii"),
|
||||
dst_port=int(msg[2].decode("ascii")),
|
||||
agent_metadata=msg[3],
|
||||
agent_name=msg[4].decode("ascii"),
|
||||
dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
|
||||
dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64),
|
||||
dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
|
||||
dst_aux_index=int(msg[8].decode("ascii")),
|
||||
dst_gpu_id=int(msg[9].decode("ascii")),
|
||||
required_dst_info_num=int(msg[10].decode("ascii")),
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -278,7 +262,7 @@ class NixlKVManager(CommonKVManager):
|
||||
for req in reqs_to_be_processed:
|
||||
assert bootstrap_room == req.room
|
||||
if req.is_dummy():
|
||||
return []
|
||||
continue
|
||||
|
||||
peer_name = self._add_remote(req.agent_name, req.agent_metadata)
|
||||
chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
|
||||
@@ -346,8 +330,7 @@ class NixlKVManager(CommonKVManager):
|
||||
), f"First message should be {GUARD}. Foreign traffic?"
|
||||
waiting_req_bytes = waiting_req_bytes[1:]
|
||||
room = waiting_req_bytes[0].decode("ascii")
|
||||
if room == "None":
|
||||
continue
|
||||
|
||||
required_dst_info_num = int(waiting_req_bytes[10].decode("ascii"))
|
||||
room = int(room)
|
||||
agent_name = waiting_req_bytes[4].decode("ascii")
|
||||
@@ -438,19 +421,6 @@ class NixlKVReceiver(CommonKVReceiver):
|
||||
)
|
||||
is_dummy = bootstrap_info["is_dummy"]
|
||||
|
||||
# TODO: just send "" for indices for dummy
|
||||
if is_dummy:
|
||||
# TODO: need to set success??
|
||||
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
||||
with lock:
|
||||
sock.send_multipart(
|
||||
[
|
||||
GUARD,
|
||||
str(self.bootstrap_room).encode("ascii"),
|
||||
]
|
||||
)
|
||||
continue
|
||||
|
||||
# TODO: send_kv_args earlier
|
||||
packed_kv_data_ptrs = b"".join(
|
||||
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
||||
@@ -473,7 +443,7 @@ class NixlKVReceiver(CommonKVReceiver):
|
||||
self.kv_mgr.agent.get_agent_metadata(),
|
||||
self.kv_mgr.agent.name.encode("ascii"),
|
||||
packed_kv_data_ptrs,
|
||||
kv_indices.tobytes(),
|
||||
kv_indices.tobytes() if not is_dummy else b"",
|
||||
packed_aux_data_ptrs,
|
||||
str(aux_index).encode("ascii"),
|
||||
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
|
||||
|
||||
Reference in New Issue
Block a user