feat: use D2D instead of H2H in pp (#7673)

Co-authored-by: alpha-baby <fujianhao1997@qq.com>
This commit is contained in:
TianyuZhang1214
2025-07-04 01:58:50 +08:00
committed by GitHub
parent 264dc6e744
commit 0099172327
3 changed files with 45 additions and 22 deletions

View File

@@ -699,18 +699,25 @@ class GroupCoordinator:
)
# Serialize object to tensor and get the size as well
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
device=torch.cuda.current_device()
)
size_tensor = torch.tensor(
[object_tensor.numel()], dtype=torch.long, device="cpu"
[object_tensor.numel()],
dtype=torch.long,
device=torch.cuda.current_device(),
)
# Send object size
torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
torch.distributed.send(
size_tensor, dst=self.ranks[dst], group=self.device_group
)
# Send object
torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)
torch.distributed.send(
object_tensor, dst=self.ranks[dst], group=self.device_group
)
return None
@@ -724,29 +731,31 @@ class GroupCoordinator:
src != self.rank_in_group
), "Invalid source rank. Source rank is the same as the current rank."
size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
size_tensor = torch.empty(
1, dtype=torch.long, device=torch.cuda.current_device()
)
# Receive object size
rank_size = torch.distributed.recv(
size_tensor, src=self.ranks[src], group=self.cpu_group
size_tensor, src=self.ranks[src], group=self.device_group
)
# Tensor to receive serialized objects into.
object_tensor = torch.empty( # type: ignore[call-overload]
size_tensor.item(), # type: ignore[arg-type]
dtype=torch.uint8,
device="cpu",
device=torch.cuda.current_device(),
)
rank_object = torch.distributed.recv(
object_tensor, src=self.ranks[src], group=self.cpu_group
object_tensor, src=self.ranks[src], group=self.device_group
)
assert (
rank_object == rank_size
), "Received object sender rank does not match the size sender rank."
obj = pickle.loads(object_tensor.numpy().tobytes())
obj = pickle.loads(object_tensor.cpu().numpy().tobytes())
return obj
@@ -857,14 +866,16 @@ class GroupCoordinator:
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict, dict
), f"Expecting a dictionary, got {type(tensor_dict)}"
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `send_object_list` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
# Note: While switching to Device-to-Device (D2D) would introduce an extra
# Device-to-Host (D2H) memory copy overhead for serialization, our benchmarks
# show better overall transmission performance with D2D due to:
# 1. Superior D2D transfer bandwidth
# 2. Ability to overlap send and recv operations
# Thus the net performance gain justifies this approach.
self.send_object(metadata_list, dst=dst)
for tensor in tensor_list:
if tensor.numel() == 0: