[BugFix]Fix bugs when initializing communication groups with dp on 300I Duo (#1478)
### What this PR does / why we need it? This PR fixes a bug that use broadcast with cpu_group when running dp. The `broadcast310p` patch will take effects for both cpu_group and device group, but we only need it for device group. Hence a wrapper is added to allow cpu_group use native torch broadcast and it solves the bug. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? With this PR, DP on 310p runs normally and generates reasonable answers. Signed-off-by: angazenn <zengyanjia@huawei.com> Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit is contained in:
@@ -77,7 +77,11 @@ class NullHandle:
|
|||||||
|
|
||||||
def communication_adaptation_310p():
|
def communication_adaptation_310p():
|
||||||
|
|
||||||
|
def broadcast310p_wrapper(fn):
|
||||||
|
|
||||||
def broadcast310p(tensor, src, group=None, async_op=False):
|
def broadcast310p(tensor, src, group=None, async_op=False):
|
||||||
|
if tensor.device == torch.device('cpu'):
|
||||||
|
return fn(tensor, src, group, async_op)
|
||||||
rank = torch.distributed.get_rank(group)
|
rank = torch.distributed.get_rank(group)
|
||||||
world_size = torch.distributed.get_world_size(group)
|
world_size = torch.distributed.get_world_size(group)
|
||||||
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
|
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
|
||||||
@@ -89,8 +93,12 @@ def communication_adaptation_310p():
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
torch.distributed.broadcast = broadcast310p
|
return broadcast310p
|
||||||
torch.distributed.distributed_c10d.broadcast = broadcast310p
|
|
||||||
|
torch.distributed.broadcast = broadcast310p_wrapper(
|
||||||
|
torch.distributed.broadcast)
|
||||||
|
torch.distributed.distributed_c10d.broadcast = broadcast310p_wrapper(
|
||||||
|
torch.distributed.distributed_c10d.broadcast)
|
||||||
|
|
||||||
def all_reduce_wrapper_310p(fn):
|
def all_reduce_wrapper_310p(fn):
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user