diff --git a/vllm_ascend/patch/platform/patch_common/patch_distributed.py b/vllm_ascend/patch/platform/patch_common/patch_distributed.py index b3db843..91f43a3 100644 --- a/vllm_ascend/patch/platform/patch_common/patch_distributed.py +++ b/vllm_ascend/patch/platform/patch_common/patch_distributed.py @@ -77,20 +77,28 @@ class NullHandle: def communication_adaptation_310p(): - def broadcast310p(tensor, src, group=None, async_op=False): - rank = torch.distributed.get_rank(group) - world_size = torch.distributed.get_world_size(group) - tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] - tensor_list[rank] = tensor - torch.distributed.all_gather(tensor_list, tensor, group=group) - tensor[...] = tensor_list[src] - if async_op: - return NullHandle() - else: - return None + def broadcast310p_wrapper(fn): - torch.distributed.broadcast = broadcast310p - torch.distributed.distributed_c10d.broadcast = broadcast310p + def broadcast310p(tensor, src, group=None, async_op=False): + if tensor.device == torch.device('cpu'): + return fn(tensor, src, group, async_op) + rank = torch.distributed.get_rank(group) + world_size = torch.distributed.get_world_size(group) + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + tensor_list[rank] = tensor + torch.distributed.all_gather(tensor_list, tensor, group=group) + tensor[...] = tensor_list[src] + if async_op: + return NullHandle() + else: + return None + + return broadcast310p + + torch.distributed.broadcast = broadcast310p_wrapper( + torch.distributed.broadcast) + torch.distributed.distributed_c10d.broadcast = broadcast310p_wrapper( + torch.distributed.distributed_c10d.broadcast) def all_reduce_wrapper_310p(fn):