Revert "Fix different device type adjustment in PP" (#8141)
This commit is contained in:
@@ -1100,15 +1100,15 @@ def broadcast_pyobj(
|
||||
rank: int,
|
||||
dist_group: Optional[torch.distributed.ProcessGroup] = None,
|
||||
src: int = 0,
|
||||
device: Optional[str] = None,
|
||||
force_cpu_device: bool = True,
|
||||
):
|
||||
"""Broadcast inputs from src rank to all other ranks with torch.dist backend.
|
||||
The `rank` here refer to the source rank on global process group (regardless
|
||||
of dist_group argument).
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
device = get_device()
|
||||
device = torch.device(
|
||||
"cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
|
||||
)
|
||||
|
||||
if rank == src:
|
||||
if len(data) == 0:
|
||||
@@ -1148,38 +1148,44 @@ def point_to_point_pyobj(
|
||||
group: Optional[torch.distributed.ProcessGroup] = None,
|
||||
src: int = 0,
|
||||
dst: int = 1,
|
||||
device: Optional[str] = None,
|
||||
):
|
||||
"""Send data from src to dst in group using DeviceToDevice communication."""
|
||||
if device is None:
|
||||
device = get_device()
|
||||
|
||||
if rank == src:
|
||||
if len(data) == 0:
|
||||
tensor_size = torch.tensor([0], dtype=torch.long, device=device)
|
||||
tensor_size = torch.tensor(
|
||||
[0], dtype=torch.long, device=torch.cuda.current_device()
|
||||
)
|
||||
dist.send(tensor_size, dst=dst, group=group)
|
||||
else:
|
||||
serialized_data = pickle.dumps(data)
|
||||
size = len(serialized_data)
|
||||
tensor_data = torch.ByteTensor(
|
||||
np.frombuffer(serialized_data, dtype=np.uint8)
|
||||
).to(
|
||||
device=device
|
||||
) # Move to Device
|
||||
tensor_size = torch.tensor([size], dtype=torch.long, device=device)
|
||||
).cuda(
|
||||
device=torch.cuda.current_device()
|
||||
) # Move to GPU
|
||||
tensor_size = torch.tensor(
|
||||
[size], dtype=torch.long, device=torch.cuda.current_device()
|
||||
)
|
||||
|
||||
dist.send(tensor_size, dst=dst, group=group)
|
||||
dist.send(tensor_data, dst=dst, group=group)
|
||||
return data
|
||||
|
||||
elif rank == dst:
|
||||
tensor_size = torch.tensor([0], dtype=torch.long, device=device)
|
||||
tensor_size = torch.tensor(
|
||||
[0], dtype=torch.long, device=torch.cuda.current_device()
|
||||
)
|
||||
dist.recv(tensor_size, src=src, group=group)
|
||||
size = tensor_size.item()
|
||||
|
||||
if size == 0:
|
||||
return []
|
||||
|
||||
tensor_data = torch.empty(size, dtype=torch.uint8, device=device)
|
||||
tensor_data = torch.empty(
|
||||
size, dtype=torch.uint8, device=torch.cuda.current_device()
|
||||
)
|
||||
dist.recv(tensor_data, src=src, group=group)
|
||||
|
||||
serialized_data = bytes(
|
||||
|
||||
Reference in New Issue
Block a user