feat: use D2D instead of H2H in pp (#7673)
Co-authored-by: alpha-baby <fujianhao1997@qq.com>
@@ -699,18 +699,25 @@ class GroupCoordinator:
         )
 
         # Serialize object to tensor and get the size as well
-        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
+        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
+            device=torch.cuda.current_device()
+        )
 
         size_tensor = torch.tensor(
-            [object_tensor.numel()], dtype=torch.long, device="cpu"
+            [object_tensor.numel()],
+            dtype=torch.long,
+            device=torch.cuda.current_device(),
         )
 
         # Send object size
-        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
+        torch.distributed.send(
+            size_tensor, dst=self.ranks[dst], group=self.device_group
+        )
 
         # Send object
-        torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)
+        torch.distributed.send(
+            object_tensor, dst=self.ranks[dst], group=self.device_group
+        )
 
         return None
 
@@ -724,29 +731,31 @@ class GroupCoordinator:
             src != self.rank_in_group
         ), "Invalid source rank. Source rank is the same as the current rank."
 
-        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
+        size_tensor = torch.empty(
+            1, dtype=torch.long, device=torch.cuda.current_device()
+        )
 
         # Receive object size
         rank_size = torch.distributed.recv(
-            size_tensor, src=self.ranks[src], group=self.cpu_group
+            size_tensor, src=self.ranks[src], group=self.device_group
         )
 
         # Tensor to receive serialized objects into.
         object_tensor = torch.empty(  # type: ignore[call-overload]
             size_tensor.item(),  # type: ignore[arg-type]
             dtype=torch.uint8,
-            device="cpu",
+            device=torch.cuda.current_device(),
         )
 
         rank_object = torch.distributed.recv(
-            object_tensor, src=self.ranks[src], group=self.cpu_group
+            object_tensor, src=self.ranks[src], group=self.device_group
         )
 
         assert (
             rank_object == rank_size
         ), "Received object sender rank does not match the size sender rank."
 
-        obj = pickle.loads(object_tensor.numpy().tobytes())
+        obj = pickle.loads(object_tensor.cpu().numpy().tobytes())
 
         return obj
 
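For context, here is a minimal standalone sketch of the D2D object-passing pattern the two hunks above implement. It is an illustration only, not part of this commit: it assumes torch.distributed is already initialized with the NCCL backend, one GPU per rank, and a device (NCCL) process group; the helper names are made up.

import pickle

import torch
import torch.distributed as dist


def d2d_send_object(obj, dst: int, group=None) -> None:
    # Serialize on the host, then stage the bytes on the current GPU so the
    # device (NCCL) group can move them GPU-to-GPU.
    payload = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
        device=torch.cuda.current_device()
    )
    size = torch.tensor(
        [payload.numel()], dtype=torch.long, device=torch.cuda.current_device()
    )
    dist.send(size, dst=dst, group=group)
    dist.send(payload, dst=dst, group=group)


def d2d_recv_object(src: int, group=None):
    size = torch.empty(1, dtype=torch.long, device=torch.cuda.current_device())
    dist.recv(size, src=src, group=group)
    payload = torch.empty(
        size.item(), dtype=torch.uint8, device=torch.cuda.current_device()
    )
    dist.recv(payload, src=src, group=group)
    # Copy back to the host only for deserialization.
    return pickle.loads(payload.cpu().numpy().tobytes())

The only host round-trips left are pickling before the send and unpickling after the receive; the size and payload tensors themselves travel GPU-to-GPU.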
@@ -857,14 +866,16 @@ class GroupCoordinator:
         dst = (self.rank_in_group + 1) % self.world_size
         assert dst < self.world_size, f"Invalid dst rank ({dst})"
 
-        metadata_list: List[Tuple[Any, Any]] = []
         assert isinstance(
             tensor_dict, dict
         ), f"Expecting a dictionary, got {type(tensor_dict)}"
         metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
-        # `metadata_list` lives in CPU memory.
-        # `send_object_list` has serialization & deserialization,
-        # all happening on CPU. Therefore, we can use the CPU group.
+        # Note: While switching to Device-to-Device (D2D) would introduce an extra
+        # Device-to-Host (D2H) memory copy overhead for serialization, our benchmarks
+        # show better overall transmission performance with D2D due to:
+        # 1. Superior D2D transfer bandwidth
+        # 2. Ability to overlap send and recv operations
+        # Thus the net performance gain justifies this approach.
         self.send_object(metadata_list, dst=dst)
         for tensor in tensor_list:
             if tensor.numel() == 0:
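The performance claim in the new comment can be checked with a micro-benchmark along these lines. This is an illustrative sketch, not the benchmark used for this PR: it assumes torchrun with exactly two ranks, one GPU per rank, and times one send/recv of the same pickled metadata payload over a Gloo (H2H) group versus the default NCCL (D2D) group.

import pickle
import time

import torch
import torch.distributed as dist


def main() -> None:
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    cpu_group = dist.new_group(backend="gloo")  # H2H path
    device_group = dist.group.WORLD             # D2D path

    payload = pickle.dumps({"metadata": list(range(200_000))})
    host_tensor = torch.frombuffer(bytearray(payload), dtype=torch.uint8)
    gpu_tensor = host_tensor.cuda()

    for name, tensor, group in (
        ("H2H (gloo)", host_tensor, cpu_group),
        ("D2D (nccl)", gpu_tensor, device_group),
    ):
        dist.barrier()
        start = time.perf_counter()
        if rank == 0:
            dist.send(tensor, dst=1, group=group)
        else:
            dist.recv(tensor, src=0, group=group)
        torch.cuda.synchronize()
        if rank == 1:
            print(f"{name}: {(time.perf_counter() - start) * 1e3:.3f} ms")


if __name__ == "__main__":
    main()

Actual numbers depend on interconnect (NVLink vs PCIe vs network), payload size, and how much the send/recv can overlap with other work, so treat this only as a template for reproducing the comparison.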
@@ -928,7 +928,7 @@ class Scheduler(
                 point_to_point_pyobj(
                     recv_reqs,
                     self.pp_rank * self.tp_size + dp_offset,
-                    self.world_group.cpu_group,
+                    self.world_group.device_group,
                     self.pp_rank * self.tp_size + dp_offset,
                     (self.pp_rank + 1) * self.tp_size + dp_offset,
                 )
@@ -975,7 +975,7 @@ class Scheduler(
                 recv_reqs = point_to_point_pyobj(
                     [],
                     self.pp_rank * self.tp_size + dp_offset,
-                    self.world_group.cpu_group,
+                    self.world_group.device_group,
                     (self.pp_rank - 1) * self.tp_size + dp_offset,
                     self.pp_rank * self.tp_size + dp_offset,
                 )
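Only the process group changes at these two call sites; the rank arithmetic is untouched. As a worked example with illustrative values (not from the diff), tp_size = 2 and dp_offset = 0 resolve the current, downstream, and upstream stages to the following global ranks:

tp_size, dp_offset = 2, 0  # illustrative values

for pp_rank in (0, 1):
    cur = pp_rank * tp_size + dp_offset        # this stage's global rank
    nxt = (pp_rank + 1) * tp_size + dp_offset  # downstream stage (send target)
    prv = (pp_rank - 1) * tp_size + dp_offset  # upstream stage (recv source)
    print(f"pp_rank={pp_rank}: rank={cur}, send_to={nxt}, recv_from={prv}")

# pp_rank=0: rank=0, send_to=2, recv_from=-2  (stage 0 has no upstream)
# pp_rank=1: rank=2, send_to=4, recv_from=0   (the last stage has no downstream)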
@@ -1000,36 +1000,48 @@ def point_to_point_pyobj(
     src: int = 0,
     dst: int = 1,
 ):
-    """Send data from src to dst in group."""
+    """Send data from src to dst in group using DeviceToDevice communication."""
 
     if rank == src:
         if len(data) == 0:
-            tensor_size = torch.tensor([0], dtype=torch.long)
+            tensor_size = torch.tensor(
+                [0], dtype=torch.long, device=torch.cuda.current_device()
+            )
             dist.send(tensor_size, dst=dst, group=group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
             tensor_data = torch.ByteTensor(
                 np.frombuffer(serialized_data, dtype=np.uint8)
-            )
-            tensor_size = torch.tensor([size], dtype=torch.long)
+            ).cuda(
+                device=torch.cuda.current_device()
+            )  # Move to GPU
+            tensor_size = torch.tensor(
+                [size], dtype=torch.long, device=torch.cuda.current_device()
+            )
 
             dist.send(tensor_size, dst=dst, group=group)
             dist.send(tensor_data, dst=dst, group=group)
         return data
 
     elif rank == dst:
-        tensor_size = torch.tensor([0], dtype=torch.long)
+        tensor_size = torch.tensor(
+            [0], dtype=torch.long, device=torch.cuda.current_device()
+        )
         dist.recv(tensor_size, src=src, group=group)
         size = tensor_size.item()
 
         if size == 0:
             return []
 
-        tensor_data = torch.empty(size, dtype=torch.uint8)
+        tensor_data = torch.empty(
+            size, dtype=torch.uint8, device=torch.cuda.current_device()
+        )
         dist.recv(tensor_data, src=src, group=group)
 
-        serialized_data = bytes(tensor_data.cpu().numpy())
+        serialized_data = bytes(
+            tensor_data.cpu().numpy()
+        )  # Move back to host for deserialization
         data = pickle.loads(serialized_data)
         return data
 
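A hedged usage sketch of the updated helper follows. It mirrors the Scheduler call sites above, assumes the positional signature point_to_point_pyobj(data, rank, group, src, dst) implied by those call sites, torchrun with two ranks and one GPU per rank, and an import path that is a guess and may differ across sglang versions.

import torch
import torch.distributed as dist

from sglang.srt.utils import point_to_point_pyobj  # assumed location


def main() -> None:
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    device_group = dist.group.WORLD  # the D2D group, as in the new call sites

    if rank == 0:
        # The sender passes the payload; the return value is the payload itself.
        point_to_point_pyobj(["req-1", "req-2"], rank, device_group, src=0, dst=1)
    else:
        # The receiver passes an empty list and gets the deserialized payload back.
        reqs = point_to_point_pyobj([], rank, device_group, src=0, dst=1)
        print(reqs)  # ['req-1', 'req-2']

    dist.destroy_process_group()


if __name__ == "__main__":
    main()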