[BugFix] Fix bugs when initializing communication groups with DP on 300I Duo (#1478)
### What this PR does / why we need it?

This PR fixes a bug where `broadcast` was routed through the cpu_group when running DP. The `broadcast310p` patch took effect for both the cpu_group and the device group, but it is only needed for the device group. A wrapper is therefore added that lets the cpu_group fall back to the native torch broadcast, which resolves the bug.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

With this PR, DP on 310p runs normally and generates reasonable answers.

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
```diff
@@ -77,20 +77,28 @@ class NullHandle:
 
 def communication_adaptation_310p():
 
-    def broadcast310p(tensor, src, group=None, async_op=False):
-        rank = torch.distributed.get_rank(group)
-        world_size = torch.distributed.get_world_size(group)
-        tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
-        tensor_list[rank] = tensor
-        torch.distributed.all_gather(tensor_list, tensor, group=group)
-        tensor[...] = tensor_list[src]
-        if async_op:
-            return NullHandle()
-        else:
-            return None
+    def broadcast310p_wrapper(fn):
 
-    torch.distributed.broadcast = broadcast310p
-    torch.distributed.distributed_c10d.broadcast = broadcast310p
+        def broadcast310p(tensor, src, group=None, async_op=False):
+            if tensor.device == torch.device('cpu'):
+                return fn(tensor, src, group, async_op)
+            rank = torch.distributed.get_rank(group)
+            world_size = torch.distributed.get_world_size(group)
+            tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
+            tensor_list[rank] = tensor
+            torch.distributed.all_gather(tensor_list, tensor, group=group)
+            tensor[...] = tensor_list[src]
+            if async_op:
+                return NullHandle()
+            else:
+                return None
+
+        return broadcast310p
+
+    torch.distributed.broadcast = broadcast310p_wrapper(
+        torch.distributed.broadcast)
+    torch.distributed.distributed_c10d.broadcast = broadcast310p_wrapper(
+        torch.distributed.distributed_c10d.broadcast)
 
     def all_reduce_wrapper_310p(fn):
```