From 78b7465cad01f97653a1384960e908c1bc9cfe0b Mon Sep 17 00:00:00 2001
From: kk <43161300+kkHuang-amd@users.noreply.github.com>
Date: Sat, 13 Sep 2025 06:05:51 +0800
Subject: [PATCH] Fix GPU fault issue when running dsv3 with dp mode and
 torch-compile enabled (#10361)

Co-authored-by: wunhuang
---
 python/sglang/srt/layers/dp_attention.py     | 24 ++++++++++++++++++++
 python/sglang/srt/layers/logits_processor.py | 20 ++++++++++++----
 2 files changed, 39 insertions(+), 5 deletions(-)

diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py
index 1250636eb..b7feccdb6 100644
--- a/python/sglang/srt/layers/dp_attention.py
+++ b/python/sglang/srt/layers/dp_attention.py
@@ -119,6 +119,18 @@ class _DpGatheredBufferWrapper:
     def get_dp_global_num_tokens(cls) -> List[int]:
         return cls._global_num_tokens
 
+    @classmethod
+    def get_dp_hidden_size(cls) -> int:
+        return cls._hidden_size
+
+    @classmethod
+    def get_dp_dtype(cls) -> torch.dtype:
+        return cls._dtype
+
+    @classmethod
+    def get_dp_device(cls) -> torch.device:
+        return cls._device
+
 
 def set_dp_buffer_len(
     global_dp_buffer_len: int,
@@ -150,6 +162,18 @@ def get_dp_global_num_tokens() -> List[int]:
     return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
 
 
+def get_dp_hidden_size() -> int:
+    return _DpGatheredBufferWrapper.get_dp_hidden_size()
+
+
+def get_dp_dtype() -> torch.dtype:
+    return _DpGatheredBufferWrapper.get_dp_dtype()
+
+
+def get_dp_device() -> torch.device:
+    return _DpGatheredBufferWrapper.get_dp_device()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index f6603907a..d465baeb4 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -35,6 +35,9 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
+    get_dp_device,
+    get_dp_dtype,
+    get_dp_hidden_size,
     get_global_dp_buffer,
     get_local_attention_dp_size,
     set_dp_buffer_len,
@@ -187,16 +190,23 @@ class LogitsMetadata:
         self.dp_local_start_pos = dp_local_start_pos
         self.dp_local_num_tokens = dp_local_num_tokens
 
+        hidden_size = get_dp_hidden_size()
+        dtype = get_dp_dtype()
+        device = get_dp_device()
+
         if self.global_num_tokens_for_logprob_cpu is not None:
             # create a smaller buffer to reduce peak memory usage
             self.global_dp_buffer_len = sum(self.global_num_tokens_for_logprob_cpu)
         else:
             self.global_dp_buffer_len = self.global_dp_buffer_len
 
-        set_dp_buffer_len(
-            self.global_dp_buffer_len,
-            self.dp_local_num_tokens,
-            self.global_num_tokens_for_logprob_cpu,
+        self.gathered_buffer = torch.empty(
+            (
+                self.global_dp_buffer_len,
+                hidden_size,
+            ),
+            dtype=dtype,
+            device=device,
         )
 
 
@@ -443,7 +453,7 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-                get_global_dp_buffer(),
+                logits_metadata.gathered_buffer,
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
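
The three new accessors only re-export values that _DpGatheredBufferWrapper
already stores (cls._hidden_size, cls._dtype, cls._device), so code outside
dp_attention.py can size a DP gather buffer without reaching into the wrapper.
Below is a minimal sketch of the allocation pattern that
compute_dp_attention_metadata() adopts in this patch; it assumes the wrapper's
metadata has already been recorded for the current forward pass, and
make_logprob_gather_buffer is a hypothetical helper name, not part of the
patch:

    import torch

    from sglang.srt.layers.dp_attention import (
        get_dp_device,
        get_dp_dtype,
        get_dp_hidden_size,
    )

    # Size the buffer from the per-rank logprob token counts, but take the
    # hidden size, dtype, and device from the DP-attention metadata recorded
    # for this forward pass, mirroring the torch.empty call in the diff.
    def make_logprob_gather_buffer(global_num_tokens_for_logprob_cpu):
        buffer_len = sum(global_num_tokens_for_logprob_cpu)
        return torch.empty(
            (buffer_len, get_dp_hidden_size()),
            dtype=get_dp_dtype(),
            device=get_dp_device(),
        )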
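Owning the buffer on LogitsMetadata, rather than fetching the module-level
buffer via get_global_dp_buffer(), means the logprob gather no longer reuses a
tensor whose shape a captured torch.compile graph may have baked in with a
different global_dp_buffer_len, which is a plausible reading of the GPU fault
named in the subject. The buffer-swap idiom in LogitsProcessor.forward is
unchanged: the freshly allocated logits_metadata.gathered_buffer becomes the
gather destination and the original hidden_states becomes the local source
passed to dp_gather_replicate().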