Revert "Fix different device type adjustment in PP" (#8141)
This commit is contained in:
@@ -699,14 +699,14 @@ class GroupCoordinator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Serialize object to tensor and get the size as well
|
# Serialize object to tensor and get the size as well
|
||||||
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).to(
|
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
|
||||||
device=self.device
|
device=torch.cuda.current_device()
|
||||||
)
|
)
|
||||||
|
|
||||||
size_tensor = torch.tensor(
|
size_tensor = torch.tensor(
|
||||||
[object_tensor.numel()],
|
[object_tensor.numel()],
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device=self.device,
|
device=torch.cuda.current_device(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send object size
|
# Send object size
|
||||||
@@ -731,7 +731,9 @@ class GroupCoordinator:
|
|||||||
src != self.rank_in_group
|
src != self.rank_in_group
|
||||||
), "Invalid source rank. Source rank is the same as the current rank."
|
), "Invalid source rank. Source rank is the same as the current rank."
|
||||||
|
|
||||||
size_tensor = torch.empty(1, dtype=torch.long, device=self.device)
|
size_tensor = torch.empty(
|
||||||
|
1, dtype=torch.long, device=torch.cuda.current_device()
|
||||||
|
)
|
||||||
|
|
||||||
# Receive object size
|
# Receive object size
|
||||||
rank_size = torch.distributed.recv(
|
rank_size = torch.distributed.recv(
|
||||||
@@ -742,7 +744,7 @@ class GroupCoordinator:
|
|||||||
object_tensor = torch.empty( # type: ignore[call-overload]
|
object_tensor = torch.empty( # type: ignore[call-overload]
|
||||||
size_tensor.item(), # type: ignore[arg-type]
|
size_tensor.item(), # type: ignore[arg-type]
|
||||||
dtype=torch.uint8,
|
dtype=torch.uint8,
|
||||||
device=self.device,
|
device=torch.cuda.current_device(),
|
||||||
)
|
)
|
||||||
|
|
||||||
rank_object = torch.distributed.recv(
|
rank_object = torch.distributed.recv(
|
||||||
|
|||||||
@@ -975,7 +975,6 @@ class Scheduler(
|
|||||||
self.world_group.device_group,
|
self.world_group.device_group,
|
||||||
self.pp_rank * self.tp_size + dp_offset,
|
self.pp_rank * self.tp_size + dp_offset,
|
||||||
(self.pp_rank + 1) * self.tp_size + dp_offset,
|
(self.pp_rank + 1) * self.tp_size + dp_offset,
|
||||||
device=self.device,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# send out proxy tensors to the next stage
|
# send out proxy tensors to the next stage
|
||||||
@@ -1024,7 +1023,6 @@ class Scheduler(
|
|||||||
self.world_group.device_group,
|
self.world_group.device_group,
|
||||||
(self.pp_rank - 1) * self.tp_size + dp_offset,
|
(self.pp_rank - 1) * self.tp_size + dp_offset,
|
||||||
self.pp_rank * self.tp_size + dp_offset,
|
self.pp_rank * self.tp_size + dp_offset,
|
||||||
device=self.device,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
recv_reqs = None
|
recv_reqs = None
|
||||||
@@ -1055,7 +1053,6 @@ class Scheduler(
|
|||||||
self.attn_tp_group.rank,
|
self.attn_tp_group.rank,
|
||||||
self.attn_tp_cpu_group,
|
self.attn_tp_cpu_group,
|
||||||
src=self.attn_tp_group.ranks[0],
|
src=self.attn_tp_group.ranks[0],
|
||||||
device=self.device,
|
|
||||||
)
|
)
|
||||||
if self.tp_size != 1:
|
if self.tp_size != 1:
|
||||||
control_reqs = broadcast_pyobj(
|
control_reqs = broadcast_pyobj(
|
||||||
@@ -1063,7 +1060,6 @@ class Scheduler(
|
|||||||
self.tp_group.rank,
|
self.tp_group.rank,
|
||||||
self.tp_cpu_group,
|
self.tp_cpu_group,
|
||||||
src=self.tp_group.ranks[0],
|
src=self.tp_group.ranks[0],
|
||||||
device=self.device,
|
|
||||||
)
|
)
|
||||||
recv_reqs = work_reqs + control_reqs
|
recv_reqs = work_reqs + control_reqs
|
||||||
elif self.tp_size != 1:
|
elif self.tp_size != 1:
|
||||||
@@ -1072,7 +1068,6 @@ class Scheduler(
|
|||||||
self.tp_group.rank,
|
self.tp_group.rank,
|
||||||
self.tp_cpu_group,
|
self.tp_cpu_group,
|
||||||
src=self.tp_group.ranks[0],
|
src=self.tp_group.ranks[0],
|
||||||
device=self.device,
|
|
||||||
)
|
)
|
||||||
return recv_reqs
|
return recv_reqs
|
||||||
|
|
||||||
|
|||||||
@@ -144,7 +144,6 @@ class TpModelWorker:
|
|||||||
self.tp_size * self.pp_rank + tp_rank,
|
self.tp_size * self.pp_rank + tp_rank,
|
||||||
self.world_group.cpu_group,
|
self.world_group.cpu_group,
|
||||||
src=self.world_group.ranks[0],
|
src=self.world_group.ranks[0],
|
||||||
device=self.device,
|
|
||||||
)[0]
|
)[0]
|
||||||
set_random_seed(self.random_seed)
|
set_random_seed(self.random_seed)
|
||||||
|
|
||||||
|
|||||||
@@ -1100,15 +1100,15 @@ def broadcast_pyobj(
|
|||||||
rank: int,
|
rank: int,
|
||||||
dist_group: Optional[torch.distributed.ProcessGroup] = None,
|
dist_group: Optional[torch.distributed.ProcessGroup] = None,
|
||||||
src: int = 0,
|
src: int = 0,
|
||||||
device: Optional[str] = None,
|
force_cpu_device: bool = True,
|
||||||
):
|
):
|
||||||
"""Broadcast inputs from src rank to all other ranks with torch.dist backend.
|
"""Broadcast inputs from src rank to all other ranks with torch.dist backend.
|
||||||
The `rank` here refer to the source rank on global process group (regardless
|
The `rank` here refer to the source rank on global process group (regardless
|
||||||
of dist_group argument).
|
of dist_group argument).
|
||||||
"""
|
"""
|
||||||
|
device = torch.device(
|
||||||
if device is None:
|
"cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
|
||||||
device = get_device()
|
)
|
||||||
|
|
||||||
if rank == src:
|
if rank == src:
|
||||||
if len(data) == 0:
|
if len(data) == 0:
|
||||||
@@ -1148,38 +1148,44 @@ def point_to_point_pyobj(
|
|||||||
group: Optional[torch.distributed.ProcessGroup] = None,
|
group: Optional[torch.distributed.ProcessGroup] = None,
|
||||||
src: int = 0,
|
src: int = 0,
|
||||||
dst: int = 1,
|
dst: int = 1,
|
||||||
device: Optional[str] = None,
|
|
||||||
):
|
):
|
||||||
"""Send data from src to dst in group using DeviceToDevice communication."""
|
"""Send data from src to dst in group using DeviceToDevice communication."""
|
||||||
if device is None:
|
|
||||||
device = get_device()
|
|
||||||
if rank == src:
|
if rank == src:
|
||||||
if len(data) == 0:
|
if len(data) == 0:
|
||||||
tensor_size = torch.tensor([0], dtype=torch.long, device=device)
|
tensor_size = torch.tensor(
|
||||||
|
[0], dtype=torch.long, device=torch.cuda.current_device()
|
||||||
|
)
|
||||||
dist.send(tensor_size, dst=dst, group=group)
|
dist.send(tensor_size, dst=dst, group=group)
|
||||||
else:
|
else:
|
||||||
serialized_data = pickle.dumps(data)
|
serialized_data = pickle.dumps(data)
|
||||||
size = len(serialized_data)
|
size = len(serialized_data)
|
||||||
tensor_data = torch.ByteTensor(
|
tensor_data = torch.ByteTensor(
|
||||||
np.frombuffer(serialized_data, dtype=np.uint8)
|
np.frombuffer(serialized_data, dtype=np.uint8)
|
||||||
).to(
|
).cuda(
|
||||||
device=device
|
device=torch.cuda.current_device()
|
||||||
) # Move to Device
|
) # Move to GPU
|
||||||
tensor_size = torch.tensor([size], dtype=torch.long, device=device)
|
tensor_size = torch.tensor(
|
||||||
|
[size], dtype=torch.long, device=torch.cuda.current_device()
|
||||||
|
)
|
||||||
|
|
||||||
dist.send(tensor_size, dst=dst, group=group)
|
dist.send(tensor_size, dst=dst, group=group)
|
||||||
dist.send(tensor_data, dst=dst, group=group)
|
dist.send(tensor_data, dst=dst, group=group)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
elif rank == dst:
|
elif rank == dst:
|
||||||
tensor_size = torch.tensor([0], dtype=torch.long, device=device)
|
tensor_size = torch.tensor(
|
||||||
|
[0], dtype=torch.long, device=torch.cuda.current_device()
|
||||||
|
)
|
||||||
dist.recv(tensor_size, src=src, group=group)
|
dist.recv(tensor_size, src=src, group=group)
|
||||||
size = tensor_size.item()
|
size = tensor_size.item()
|
||||||
|
|
||||||
if size == 0:
|
if size == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
tensor_data = torch.empty(size, dtype=torch.uint8, device=device)
|
tensor_data = torch.empty(
|
||||||
|
size, dtype=torch.uint8, device=torch.cuda.current_device()
|
||||||
|
)
|
||||||
dist.recv(tensor_data, src=src, group=group)
|
dist.recv(tensor_data, src=src, group=group)
|
||||||
|
|
||||||
serialized_data = bytes(
|
serialized_data = bytes(
|
||||||
|
|||||||
Reference in New Issue
Block a user