Fix different device type adjustment in PP (#7760)

This commit is contained in:
Qiaolin Yu
2025-07-15 19:37:14 -07:00
committed by GitHub
parent 7498522f7d
commit 3bc43c683e
4 changed files with 25 additions and 27 deletions

View File

@@ -1094,15 +1094,15 @@ def broadcast_pyobj(
rank: int,
dist_group: Optional[torch.distributed.ProcessGroup] = None,
src: int = 0,
force_cpu_device: bool = True,
device: Optional[str] = None,
):
"""Broadcast inputs from src rank to all other ranks with torch.dist backend.
The `rank` here refer to the source rank on global process group (regardless
of dist_group argument).
"""
device = torch.device(
"cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
)
if device is None:
device = get_device()
if rank == src:
if len(data) == 0:
@@ -1142,44 +1142,38 @@ def point_to_point_pyobj(
group: Optional[torch.distributed.ProcessGroup] = None,
src: int = 0,
dst: int = 1,
device: Optional[str] = None,
):
"""Send data from src to dst in group using DeviceToDevice communication."""
if device is None:
device = get_device()
if rank == src:
if len(data) == 0:
tensor_size = torch.tensor(
[0], dtype=torch.long, device=torch.cuda.current_device()
)
tensor_size = torch.tensor([0], dtype=torch.long, device=device)
dist.send(tensor_size, dst=dst, group=group)
else:
serialized_data = pickle.dumps(data)
size = len(serialized_data)
tensor_data = torch.ByteTensor(
np.frombuffer(serialized_data, dtype=np.uint8)
).cuda(
device=torch.cuda.current_device()
) # Move to GPU
tensor_size = torch.tensor(
[size], dtype=torch.long, device=torch.cuda.current_device()
)
).to(
device=device
) # Move to Device
tensor_size = torch.tensor([size], dtype=torch.long, device=device)
dist.send(tensor_size, dst=dst, group=group)
dist.send(tensor_data, dst=dst, group=group)
return data
elif rank == dst:
tensor_size = torch.tensor(
[0], dtype=torch.long, device=torch.cuda.current_device()
)
tensor_size = torch.tensor([0], dtype=torch.long, device=device)
dist.recv(tensor_size, src=src, group=group)
size = tensor_size.item()
if size == 0:
return []
tensor_data = torch.empty(
size, dtype=torch.uint8, device=torch.cuda.current_device()
)
tensor_data = torch.empty(size, dtype=torch.uint8, device=device)
dist.recv(tensor_data, src=src, group=group)
serialized_data = bytes(