Fix GPU fault issue when running dsv3 with dp mode and torch-compile enabled (#10361)
Co-authored-by: wunhuang <wunhuang@amd.com>
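The commit ships no description, but the hunks below read as one change: expose the DP gather buffer's allocation parameters (hidden size, dtype, device) from `_DpGatheredBufferWrapper`, let `LogitsMetadata` allocate its own right-sized gather buffer from them, and make `LogitsProcessor` gather into that buffer instead of the shared global one, which evidently is what faulted under torch-compile.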
In the DP-attention module (`sglang.srt.layers.dp_attention`, per the import block further down), three accessors are added to `_DpGatheredBufferWrapper` alongside the existing `get_dp_global_num_tokens`:

```diff
@@ -119,6 +119,18 @@ class _DpGatheredBufferWrapper:
     def get_dp_global_num_tokens(cls) -> List[int]:
         return cls._global_num_tokens
+
+    @classmethod
+    def get_dp_hidden_size(cls) -> int:
+        return cls._hidden_size
+
+    @classmethod
+    def get_dp_dtype(cls) -> torch.dtype:
+        return cls._dtype
+
+    @classmethod
+    def get_dp_device(cls) -> torch.device:
+        return cls._device
 
 
 def set_dp_buffer_len(
     global_dp_buffer_len: int,
```
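The wrapper keeps its state in class attributes, so these getters work without any instance ever being constructed. A minimal stand-in showing the pattern; `_Registry` and its default values are illustrative, not SGLang code:

```python
# Class-attribute registry: state lives on the class itself, so any
# module can read it without holding an instance. Defaults are assumed
# for illustration (7168 is DeepSeek-V3's hidden size).
import torch


class _Registry:
    _hidden_size: int = 7168
    _dtype: torch.dtype = torch.bfloat16
    _device: torch.device = torch.device("cpu")

    @classmethod
    def get_dp_hidden_size(cls) -> int:
        return cls._hidden_size

    @classmethod
    def get_dp_dtype(cls) -> torch.dtype:
        return cls._dtype


# Readable from anywhere, no instantiation required:
assert _Registry.get_dp_hidden_size() == 7168
assert _Registry.get_dp_dtype() is torch.bfloat16
```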
Matching module-level wrappers follow, next to the existing `get_dp_global_num_tokens` function:

```diff
@@ -150,6 +162,18 @@ def get_dp_global_num_tokens() -> List[int]:
     return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
+
+
+def get_dp_hidden_size() -> int:
+    return _DpGatheredBufferWrapper.get_dp_hidden_size()
+
+
+def get_dp_dtype() -> torch.dtype:
+    return _DpGatheredBufferWrapper.get_dp_dtype()
+
+
+def get_dp_device() -> torch.device:
+    return _DpGatheredBufferWrapper.get_dp_device()
 
 
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
```
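The module-level functions are thin pass-throughs that keep `_DpGatheredBufferWrapper` private to the module. A sketch of how a caller might use them, assuming the runtime has already initialized the wrapper; `alloc_logprob_buffer` is a hypothetical helper, not part of the diff:

```python
# Hypothetical helper built on the new accessors: allocate a gather
# destination without touching the shared global buffer. Assumes
# _DpGatheredBufferWrapper's metadata has already been populated.
import torch

from sglang.srt.layers.dp_attention import (
    get_dp_device,
    get_dp_dtype,
    get_dp_hidden_size,
)


def alloc_logprob_buffer(total_tokens: int) -> torch.Tensor:
    # One row per token gathered across all DP ranks.
    return torch.empty(
        (total_tokens, get_dp_hidden_size()),
        dtype=get_dp_dtype(),
        device=get_dp_device(),
    )
```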
The logits-processor module then imports the new accessors:

```diff
@@ -35,6 +35,9 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
+    get_dp_device,
+    get_dp_dtype,
+    get_dp_hidden_size,
     get_global_dp_buffer,
     get_local_attention_dp_size,
     set_dp_buffer_len,
```
In `LogitsMetadata.compute_dp_attention_metadata`, the new accessors supply the shape, dtype, and device for a gather buffer the metadata now owns:

```diff
@@ -187,16 +190,23 @@ class LogitsMetadata:
         self.dp_local_start_pos = dp_local_start_pos
         self.dp_local_num_tokens = dp_local_num_tokens
 
+        hidden_size = get_dp_hidden_size()
+        dtype = get_dp_dtype()
+        device = get_dp_device()
+
         if self.global_num_tokens_for_logprob_cpu is not None:
             # create a smaller buffer to reduce peak memory usage
             self.global_dp_buffer_len = sum(self.global_num_tokens_for_logprob_cpu)
         else:
             self.global_dp_buffer_len = self.global_dp_buffer_len
 
         set_dp_buffer_len(
             self.global_dp_buffer_len,
             self.dp_local_num_tokens,
             self.global_num_tokens_for_logprob_cpu,
         )
+        self.gathered_buffer = torch.empty(
+            (
+                self.global_dp_buffer_len,
+                hidden_size,
+            ),
+            dtype=dtype,
+            device=device,
+        )
```
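The `global_num_tokens_for_logprob_cpu` branch is what keeps the new buffer small: only the positions that actually need logprobs are gathered. Illustrative numbers (assumed, not from the commit):

```python
# Assumed example: 4 DP ranks, 512 extend tokens each, logprobs needed
# only for the final position of each sequence.
global_num_tokens = [512, 512, 512, 512]
global_num_tokens_for_logprob = [1, 1, 1, 1]

full_rows = sum(global_num_tokens)                 # 2048 rows
logprob_rows = sum(global_num_tokens_for_logprob)  # 4 rows

# With hidden size 7168 and bf16 (2 bytes per element), the gathered
# buffer shrinks from 2048 * 7168 * 2 = 28 MiB to 4 * 7168 * 2 = 56 KiB.
```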
Finally, the one-line change in `LogitsProcessor` switches the gather destination from the shared global buffer to the metadata-owned one:

```diff
@@ -443,7 +453,7 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-                get_global_dp_buffer(),
+                logits_metadata.gathered_buffer,
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
```
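This swap is the fix proper: the destination handed to `dp_gather_replicate` is now the tensor `compute_dp_attention_metadata` just allocated, sized for the logprob pass, rather than the process-wide buffer behind `get_global_dp_buffer()`, whose reuse evidently tripped torch-compile. For orientation, a conceptual sketch of gathering into a preallocated destination; this is not SGLang's `dp_gather_replicate`, and it assumes equal token counts per rank so plain `all_gather` applies:

```python
# Conceptual sketch only: fill a preallocated destination with every DP
# rank's local rows. Assumes torch.distributed is initialized and each
# rank contributes the same number of rows.
import torch
import torch.distributed as dist


def dp_gather_sketch(dest: torch.Tensor, local: torch.Tensor) -> None:
    dp_size = dist.get_world_size()
    # Split the destination into one equal view per rank; all_gather
    # writes each rank's rows into its slot of dest in place.
    chunks = list(dest.chunk(dp_size, dim=0))
    dist.all_gather(chunks, local.contiguous())
```

In the real code the per-rank counts vary, which is why the metadata carries `global_num_tokens_for_logprob_cpu` and a `dp_local_start_pos` for each rank.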