[DP Attention] Refactor: adding some utility functions (#9136)

This commit is contained in:
Cheng Wan
2025-08-13 21:08:06 -07:00
committed by GitHub
parent b3363cc1aa
commit b87aacb5c5
21 changed files with 216 additions and 159 deletions

View File

@@ -34,9 +34,10 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
)
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
from sglang.srt.layers.dp_attention import (
DPPaddingMode,
DpPaddingMode,
get_attention_tp_rank,
get_attention_tp_size,
set_dp_buffer_len,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -349,30 +350,15 @@ class CudaGraphRunner:
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
self.gathered_buffer = torch.zeros(
(
self.max_num_token * self.dp_size,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
else:
assert self.require_attn_tp_gather
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(1,), dtype=torch.int32
)
self.gathered_buffer = torch.zeros(
(
self.max_num_token,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
else:
self.global_num_tokens_gpu = None
self.global_num_tokens_for_logprob_gpu = None
self.gathered_buffer = None
self.custom_mask = torch.ones(
(
@@ -556,7 +542,7 @@ class CudaGraphRunner:
device=input_ids.device,
)
)
gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
global_dp_buffer_len = num_tokens * self.dp_size
elif self.require_attn_tp_gather:
self.global_num_tokens_gpu.copy_(
torch.tensor(
@@ -572,9 +558,9 @@ class CudaGraphRunner:
device=input_ids.device,
)
)
gathered_buffer = self.gathered_buffer[:num_tokens]
global_dp_buffer_len = num_tokens
else:
gathered_buffer = None
global_dp_buffer_len = None
spec_info = self.get_spec_info(num_tokens)
if self.capture_hidden_mode != CaptureHiddenMode.FULL:
@@ -607,8 +593,8 @@ class CudaGraphRunner:
positions=positions,
global_num_tokens_gpu=self.global_num_tokens_gpu,
global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
gathered_buffer=gathered_buffer,
dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
global_dp_buffer_len=global_dp_buffer_len,
mrope_positions=mrope_positions,
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
@@ -637,6 +623,7 @@ class CudaGraphRunner:
def run_once():
# Clean intermediate result cache for DP attention
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
set_dp_buffer_len(global_dp_buffer_len, num_tokens)
kwargs = {}
if (

View File

@@ -40,9 +40,10 @@ import triton.language as tl
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.dp_attention import (
DPPaddingMode,
DpPaddingMode,
get_attention_dp_rank,
get_attention_tp_size,
set_dp_buffer_len,
)
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import (
@@ -274,13 +275,13 @@ class ForwardBatch:
global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
# The padding mode for DP attention
dp_padding_mode: Optional[DPPaddingMode] = None
dp_padding_mode: Optional[DpPaddingMode] = None
# for extend, local start pos and num tokens is different in logits processor
# this will be computed in get_dp_local_info
# this will be recomputed in LogitsMetadata.from_forward_batch
dp_local_start_pos: Optional[torch.Tensor] = None # cached info at runtime
dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime
gathered_buffer: Optional[torch.Tensor] = None
global_dp_buffer_len: Optional[int] = None
is_extend_in_batch: bool = False
can_run_dp_cuda_graph: bool = False
global_forward_mode: Optional[ForwardMode] = None
@@ -628,7 +629,7 @@ class ForwardBatch:
(global_num_tokens[i] - 1) // attn_tp_size + 1
) * attn_tp_size
dp_padding_mode = DPPaddingMode.get_dp_padding_mode(global_num_tokens)
dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
self.dp_padding_mode = dp_padding_mode
if dp_padding_mode.is_max_len():
@@ -642,17 +643,14 @@ class ForwardBatch:
else:
buffer_len = sum(global_num_tokens)
self.gathered_buffer = torch.zeros(
(buffer_len, model_runner.model_config.hidden_size),
dtype=model_runner.dtype,
device=model_runner.device,
)
if len(global_num_tokens) > 1:
num_tokens = global_num_tokens[get_attention_dp_rank()]
else:
num_tokens = global_num_tokens[0]
self.global_dp_buffer_len = buffer_len
set_dp_buffer_len(buffer_len, num_tokens)
bs = self.batch_size
if self.forward_mode.is_decode():

View File

@@ -603,12 +603,8 @@ class ModelRunner:
duplicate_tp_group=self.server_args.enable_pdmux,
)
initialize_dp_attention(
enable_dp_attention=self.server_args.enable_dp_attention,
tp_rank=self.tp_rank,
tp_size=self.tp_size,
dp_size=self.server_args.dp_size,
moe_dense_tp_size=self.server_args.moe_dense_tp_size,
pp_size=self.server_args.pp_size,
server_args=self.server_args,
model_config=self.model_config,
)
min_per_gpu_memory = get_available_gpu_memory(