From 8311b07fb932768bfd04519da9f1bc5ae8bfa7f5 Mon Sep 17 00:00:00 2001
From: mlmz <54172054+minleminzui@users.noreply.github.com>
Date: Sun, 13 Apr 2025 13:50:37 +0800
Subject: [PATCH] Fix: Ensure tensors for dist.broadcast match NCCL backend device (#5322)

---
 python/sglang/srt/utils.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 4a68dce74..64d2e66eb 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -848,31 +848,33 @@ def broadcast_pyobj(
     src: int = 0,
 ):
     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     if rank == 0:
         if len(data) == 0:
-            tensor_size = torch.tensor([0], dtype=torch.long)
+            tensor_size = torch.tensor([0], dtype=torch.long, device=device)
             dist.broadcast(tensor_size, src=src, group=dist_group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
+
             tensor_data = torch.ByteTensor(
                 np.frombuffer(serialized_data, dtype=np.uint8)
-            )
-            tensor_size = torch.tensor([size], dtype=torch.long)
+            ).to(device)
+            tensor_size = torch.tensor([size], dtype=torch.long, device=device)
 
             dist.broadcast(tensor_size, src=src, group=dist_group)
             dist.broadcast(tensor_data, src=src, group=dist_group)
         return data
     else:
-        tensor_size = torch.tensor([0], dtype=torch.long)
+        tensor_size = torch.tensor([0], dtype=torch.long, device=device)
         dist.broadcast(tensor_size, src=src, group=dist_group)
         size = tensor_size.item()
 
         if size == 0:
             return []
 
-        tensor_data = torch.empty(size, dtype=torch.uint8)
+        tensor_data = torch.empty(size, dtype=torch.uint8, device=device)
         dist.broadcast(tensor_data, src=src, group=dist_group)
 
         serialized_data = bytes(tensor_data.cpu().numpy())