Revert "Fix different device type adjustment in PP" (#8141)
@@ -699,14 +699,14 @@ class GroupCoordinator:
         )
 
         # Serialize object to tensor and get the size as well
-        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).to(
-            device=self.device
+        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
+            device=torch.cuda.current_device()
         )
 
         size_tensor = torch.tensor(
             [object_tensor.numel()],
             dtype=torch.long,
-            device=self.device,
+            device=torch.cuda.current_device(),
         )
 
         # Send object size
@@ -731,7 +731,9 @@ class GroupCoordinator:
             src != self.rank_in_group
         ), "Invalid source rank. Source rank is the same as the current rank."
 
-        size_tensor = torch.empty(1, dtype=torch.long, device=self.device)
+        size_tensor = torch.empty(
+            1, dtype=torch.long, device=torch.cuda.current_device()
+        )
 
         # Receive object size
         rank_size = torch.distributed.recv(
@@ -742,7 +744,7 @@ class GroupCoordinator:
         object_tensor = torch.empty(  # type: ignore[call-overload]
             size_tensor.item(),  # type: ignore[arg-type]
             dtype=torch.uint8,
-            device=self.device,
+            device=torch.cuda.current_device(),
         )
 
         rank_object = torch.distributed.recv(
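
For context, the hunks above implement a size-then-payload transfer: the sender pickles the object into a uint8 tensor on the GPU, sends its length first, then the bytes; the receiver mirrors the sequence. Below is a minimal standalone sketch of that pattern, assuming an already-initialized torch.distributed process group with a CUDA-capable backend such as NCCL; the function names and the freestanding (non-GroupCoordinator) form are illustrative, not vLLM's actual API.

    import pickle

    import torch
    import torch.distributed as dist

    def send_object(obj, dst: int) -> None:
        # Pickle into a byte tensor on the current CUDA device
        # (bytearray gives frombuffer a writable buffer).
        payload = torch.frombuffer(
            bytearray(pickle.dumps(obj)), dtype=torch.uint8
        ).cuda(device=torch.cuda.current_device())
        # Send the payload size first so the receiver can allocate a buffer.
        size = torch.tensor(
            [payload.numel()], dtype=torch.long, device=torch.cuda.current_device()
        )
        dist.send(size, dst=dst)
        dist.send(payload, dst=dst)

    def recv_object(src: int):
        # Receive the size, allocate a matching buffer, then receive the bytes.
        size = torch.empty(1, dtype=torch.long, device=torch.cuda.current_device())
        dist.recv(size, src=src)
        payload = torch.empty(
            size.item(), dtype=torch.uint8, device=torch.cuda.current_device()
        )
        dist.recv(payload, src=src)
        return pickle.loads(payload.cpu().numpy().tobytes())

Allocating these buffers with torch.cuda.current_device() pins them to the device NCCL expects for this rank, which is what the revert restores in place of self.device.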
||||