From 4a6e7a66a02ac7ab966868933a59ddbb7c153c6b Mon Sep 17 00:00:00 2001
From: kk <43161300+kkHuang-amd@users.noreply.github.com>
Date: Fri, 1 Aug 2025 07:15:43 +0800
Subject: [PATCH] Fix nan value generated after custom all reduce (#8532)

---
 .../device_communicators/custom_all_reduce.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
index a1d28f2fc..92da10112 100644
--- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
@@ -184,7 +184,7 @@ class CustomAllreduce:
             # 8*world_size bytes where world_size is at most 8. Allocating 8MB
             # is enough for 131072 such tuples. The largest model I've seen only
             # needs less than 10000 of registered tuples.
-            self.rank_data = torch.empty(
+            self.rank_data = torch.zeros(
                 8 * 1024 * 1024, dtype=torch.uint8, device=self.device
             )
             self._ptr = ops.init_custom_ar(
@@ -194,14 +194,14 @@ class CustomAllreduce:
         else:
             # meta data buffers need to be "uncached" for signal on MI200
             self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
-            self.buffer = torch.empty(max_size, dtype=torch.uint8, device=self.device)
+            self.buffer = torch.zeros(max_size, dtype=torch.uint8, device=self.device)
             handle = ops.get_meta_buffer_ipc_handle(self.meta)
             shard_data = (
                 bytes(handle),  # ipc handle to base ptr
                 0,  # offset of base ptr
             )
             handles, offsets = self._gather_ipc_meta(shard_data)
-            self.rank_data = torch.empty(
+            self.rank_data = torch.zeros(
                 8 * 1024 * 1024, dtype=torch.uint8, device=self.device
             )
             self._ptr = ops.init_custom_ar(
@@ -350,14 +350,14 @@ class CustomAllreduce:
     # or, in the context of cuda graphs, register_graph_buffers
     def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
         if out is None:
-            out = torch.empty_like(inp)
+            out = torch.zeros_like(inp)
         ops.all_reduce_reg(self._ptr, inp, out)
         return out
 
     # all reduce, assuming inp tensor is NOT IPC registered
     def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
         if out is None:
-            out = torch.empty_like(inp)
+            out = torch.zeros_like(inp)
         ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
         return out
 
@@ -375,7 +375,7 @@ class CustomAllreduce:
         buffer.
         """
         if out is None:
-            out = torch.empty_like(inp)
+            out = torch.zeros_like(inp)
         if registered:
             ops.all_reduce(self._ptr, inp, out, 0, 0)
         else:
@@ -398,7 +398,7 @@ class CustomAllreduce:
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
-                return torch.empty_like(input)
+                return torch.zeros_like(input)
         else:
             if _is_hip:
                 # note: outside of cuda graph context,
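
Note on the failure mode (illustrative sketch, not part of the patch):
torch.empty only reserves memory and leaves its contents as whatever bytes
the allocator hands back, so a freshly allocated buffer can already decode
to NaN. The cuda-graph warm-up path above makes this concrete: it returns
torch.empty_like(input) as the stand-in all-reduce result without running
the kernel, so stale bytes flow straight into downstream activations.
torch.zeros pays one memset to guarantee a clean start. A minimal
standalone sketch assuming only stock PyTorch; whether empty actually
yields NaN on a given run depends on allocator state, so the first check
is nondeterministic:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # torch.empty: allocation only -- contents are whatever bytes were left
    # behind, which may parse as NaN/Inf when viewed as float16.
    uninit = torch.empty(1 << 20, dtype=torch.float16, device=device)
    print("uninitialized buffer has NaN:", torch.isnan(uninit).any().item())

    # torch.zeros: same allocation plus a memset -- every element is 0.0,
    # so nothing stale can leak out of a buffer that is only partially
    # overwritten before being read.
    zeroed = torch.zeros(1 << 20, dtype=torch.float16, device=device)
    assert not torch.isnan(zeroed).any()

In steady state the collective overwrites every output element, so the
extra memset is a one-time cost for the persistent buffers and a cheap one
for per-call outputs; zeroing is presumably the simpler fix compared with
auditing every read of these buffers in the kernels.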