Fix: Ensure tensors for dist.broadcast match NCCL backend device (#5322)
This commit is contained in:
@@ -848,31 +848,34 @@ def broadcast_pyobj(
|
|||||||
src: int = 0,
|
src: int = 0,
|
||||||
):
|
):
|
||||||
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
|
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"[broadcast_pyobj] rank={rank}, device={device}")
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
if len(data) == 0:
|
if len(data) == 0:
|
||||||
tensor_size = torch.tensor([0], dtype=torch.long)
|
tensor_size = torch.tensor([0], dtype=torch.long, device=device)
|
||||||
dist.broadcast(tensor_size, src=src, group=dist_group)
|
dist.broadcast(tensor_size, src=src, group=dist_group)
|
||||||
else:
|
else:
|
||||||
serialized_data = pickle.dumps(data)
|
serialized_data = pickle.dumps(data)
|
||||||
size = len(serialized_data)
|
size = len(serialized_data)
|
||||||
|
|
||||||
tensor_data = torch.ByteTensor(
|
tensor_data = torch.ByteTensor(
|
||||||
np.frombuffer(serialized_data, dtype=np.uint8)
|
np.frombuffer(serialized_data, dtype=np.uint8)
|
||||||
)
|
).to(device)
|
||||||
tensor_size = torch.tensor([size], dtype=torch.long)
|
tensor_size = torch.tensor([size], dtype=torch.long, device=device)
|
||||||
|
|
||||||
dist.broadcast(tensor_size, src=src, group=dist_group)
|
dist.broadcast(tensor_size, src=src, group=dist_group)
|
||||||
dist.broadcast(tensor_data, src=src, group=dist_group)
|
dist.broadcast(tensor_data, src=src, group=dist_group)
|
||||||
return data
|
return data
|
||||||
else:
|
else:
|
||||||
tensor_size = torch.tensor([0], dtype=torch.long)
|
tensor_size = torch.tensor([0], dtype=torch.long, device=device)
|
||||||
dist.broadcast(tensor_size, src=src, group=dist_group)
|
dist.broadcast(tensor_size, src=src, group=dist_group)
|
||||||
size = tensor_size.item()
|
size = tensor_size.item()
|
||||||
|
|
||||||
if size == 0:
|
if size == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
tensor_data = torch.empty(size, dtype=torch.uint8)
|
tensor_data = torch.empty(size, dtype=torch.uint8, device=device)
|
||||||
dist.broadcast(tensor_data, src=src, group=dist_group)
|
dist.broadcast(tensor_data, src=src, group=dist_group)
|
||||||
|
|
||||||
serialized_data = bytes(tensor_data.cpu().numpy())
|
serialized_data = bytes(tensor_data.cpu().numpy())
|
||||||
|
|||||||
Reference in New Issue
Block a user