Fix nan value generated after custom all reduce (#8532)
This commit is contained in:
@@ -184,7 +184,7 @@ class CustomAllreduce:
|
|||||||
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
|
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
|
||||||
# is enough for 131072 such tuples. The largest model I've seen only
|
# is enough for 131072 such tuples. The largest model I've seen only
|
||||||
# needs less than 10000 of registered tuples.
|
# needs less than 10000 of registered tuples.
|
||||||
self.rank_data = torch.empty(
|
self.rank_data = torch.zeros(
|
||||||
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
|
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
|
||||||
)
|
)
|
||||||
self._ptr = ops.init_custom_ar(
|
self._ptr = ops.init_custom_ar(
|
||||||
@@ -194,14 +194,14 @@ class CustomAllreduce:
|
|||||||
else:
|
else:
|
||||||
# meta data buffers need to be "uncached" for signal on MI200
|
# meta data buffers need to be "uncached" for signal on MI200
|
||||||
self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
|
self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
|
||||||
self.buffer = torch.empty(max_size, dtype=torch.uint8, device=self.device)
|
self.buffer = torch.zeros(max_size, dtype=torch.uint8, device=self.device)
|
||||||
handle = ops.get_meta_buffer_ipc_handle(self.meta)
|
handle = ops.get_meta_buffer_ipc_handle(self.meta)
|
||||||
shard_data = (
|
shard_data = (
|
||||||
bytes(handle), # ipc handle to base ptr
|
bytes(handle), # ipc handle to base ptr
|
||||||
0, # offset of base ptr
|
0, # offset of base ptr
|
||||||
)
|
)
|
||||||
handles, offsets = self._gather_ipc_meta(shard_data)
|
handles, offsets = self._gather_ipc_meta(shard_data)
|
||||||
self.rank_data = torch.empty(
|
self.rank_data = torch.zeros(
|
||||||
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
|
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
|
||||||
)
|
)
|
||||||
self._ptr = ops.init_custom_ar(
|
self._ptr = ops.init_custom_ar(
|
||||||
@@ -350,14 +350,14 @@ class CustomAllreduce:
|
|||||||
# or, in the context of cuda graphs, register_graph_buffers
|
# or, in the context of cuda graphs, register_graph_buffers
|
||||||
def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
|
def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
|
||||||
if out is None:
|
if out is None:
|
||||||
out = torch.empty_like(inp)
|
out = torch.zeros_like(inp)
|
||||||
ops.all_reduce_reg(self._ptr, inp, out)
|
ops.all_reduce_reg(self._ptr, inp, out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
# all reduce, assuming inp tensor is NOT IPC registered
|
# all reduce, assuming inp tensor is NOT IPC registered
|
||||||
def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
|
def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
|
||||||
if out is None:
|
if out is None:
|
||||||
out = torch.empty_like(inp)
|
out = torch.zeros_like(inp)
|
||||||
ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
|
ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@@ -375,7 +375,7 @@ class CustomAllreduce:
|
|||||||
buffer.
|
buffer.
|
||||||
"""
|
"""
|
||||||
if out is None:
|
if out is None:
|
||||||
out = torch.empty_like(inp)
|
out = torch.zeros_like(inp)
|
||||||
if registered:
|
if registered:
|
||||||
ops.all_reduce(self._ptr, inp, out, 0, 0)
|
ops.all_reduce(self._ptr, inp, out, 0, 0)
|
||||||
else:
|
else:
|
||||||
@@ -398,7 +398,7 @@ class CustomAllreduce:
|
|||||||
else:
|
else:
|
||||||
# If warm up, mimic the allocation pattern since custom
|
# If warm up, mimic the allocation pattern since custom
|
||||||
# allreduce is out-of-place.
|
# allreduce is out-of-place.
|
||||||
return torch.empty_like(input)
|
return torch.zeros_like(input)
|
||||||
else:
|
else:
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
# note: outside of cuda graph context,
|
# note: outside of cuda graph context,
|
||||||
|
|||||||
Reference in New Issue
Block a user