From 5c53cbaf2a7efcd09f2860eb6dff7412c29654b8 Mon Sep 17 00:00:00 2001 From: Angazenn <92204292+Angazenn@users.noreply.github.com> Date: Sat, 28 Jun 2025 16:07:52 +0800 Subject: [PATCH] [BugFix]Fix bugs when initializing communication groups with dp on 300I Duo (#1478) ### What this PR does / why we need it? This PR fixes a bug that uses broadcast with cpu_group when running dp. The `broadcast310p` patch will take effect for both cpu_group and device group, but we only need it for device group. Hence a wrapper is added to allow cpu_group to use native torch broadcast and it solves the bug. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? With this PR, DP on 310p runs normally and generates reasonable answers. Signed-off-by: angazenn Co-authored-by: angazenn --- .../patch_common/patch_distributed.py | 34 ++++++++++++------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/vllm_ascend/patch/platform/patch_common/patch_distributed.py b/vllm_ascend/patch/platform/patch_common/patch_distributed.py index b3db843..91f43a3 100644 --- a/vllm_ascend/patch/platform/patch_common/patch_distributed.py +++ b/vllm_ascend/patch/platform/patch_common/patch_distributed.py @@ -77,20 +77,28 @@ class NullHandle: def communication_adaptation_310p(): - def broadcast310p(tensor, src, group=None, async_op=False): - rank = torch.distributed.get_rank(group) - world_size = torch.distributed.get_world_size(group) - tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] - tensor_list[rank] = tensor - torch.distributed.all_gather(tensor_list, tensor, group=group) - tensor[...] 
= tensor_list[src] - if async_op: - return NullHandle() - else: - return None + def broadcast310p_wrapper(fn): - torch.distributed.broadcast = broadcast310p - torch.distributed.distributed_c10d.broadcast = broadcast310p + def broadcast310p(tensor, src, group=None, async_op=False): + if tensor.device == torch.device('cpu'): + return fn(tensor, src, group, async_op) + rank = torch.distributed.get_rank(group) + world_size = torch.distributed.get_world_size(group) + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + tensor_list[rank] = tensor + torch.distributed.all_gather(tensor_list, tensor, group=group) + tensor[...] = tensor_list[src] + if async_op: + return NullHandle() + else: + return None + + return broadcast310p + + torch.distributed.broadcast = broadcast310p_wrapper( + torch.distributed.broadcast) + torch.distributed.distributed_c10d.broadcast = broadcast310p_wrapper( + torch.distributed.distributed_c10d.broadcast) def all_reduce_wrapper_310p(fn):