[DP Attention] Refactor: adding some utility functions (#9136)

Author: Cheng Wan
Date: 2025-08-13 21:08:06 -07:00
Committed by: GitHub
Parent commit: b3363cc1aa
Commit: b87aacb5c5

21 changed files with 216 additions and 159 deletions
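What the refactor does: instead of each CUDA graph runner pre-allocating its own gathered_buffer tensor and passing a slice of it through the ForwardBatch, the runners now pass only the integer global_dp_buffer_len and report it to the DP-attention layer via the new set_dp_buffer_len helper; the DPPaddingMode enum is also renamed to DpPaddingMode. The changes to sglang/srt/layers/dp_attention.py itself are not part of the hunks shown below, so the following is only a minimal sketch of what such a helper could look like, assuming it records the lengths in module-level state; every name other than set_dp_buffer_len and its call signature is illustrative.

    # Hedged sketch, not the real implementation: only set_dp_buffer_len's name and
    # its (global_dp_buffer_len, num_tokens) signature are taken from the diff below.
    from typing import Optional

    _global_dp_buffer_len: Optional[int] = None  # padded token count across all DP ranks
    _local_dp_buffer_len: Optional[int] = None   # this rank's token count

    def set_dp_buffer_len(global_dp_buffer_len: Optional[int], num_tokens: int) -> None:
        """Record the DP gather sizes for the upcoming forward pass."""
        global _global_dp_buffer_len, _local_dp_buffer_len
        _global_dp_buffer_len = global_dp_buffer_len
        _local_dp_buffer_len = num_tokens

    def get_global_dp_buffer_len() -> Optional[int]:
        # Hypothetical accessor: code that allocates the gathered workspace would read
        # the recorded length here instead of receiving a pre-sliced tensor.
        return _global_dp_buffer_len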

View File

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable
 import torch
 
-from sglang.srt.layers.dp_attention import DPPaddingMode
+from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
@@ -105,30 +105,15 @@ class EAGLEDraftCudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token * self.dp_size,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
         else:
             self.global_num_tokens_gpu = None
             self.global_num_tokens_for_logprob_gpu = None
-            self.gathered_buffer = None
 
         # Capture
         try:
@@ -193,7 +178,7 @@ class EAGLEDraftCudaGraphRunner:
                 )
             )
             global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+            global_dp_buffer_len = num_tokens * self.dp_size
             global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
@@ -211,11 +196,11 @@ class EAGLEDraftCudaGraphRunner:
                 )
             )
             global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_dp_buffer_len = num_tokens
             global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         else:
             global_num_tokens = None
-            gathered_buffer = None
+            global_dp_buffer_len = None
             global_num_tokens_for_logprob = None
 
         spec_info = EagleDraftInput(
@@ -239,8 +224,8 @@ class EAGLEDraftCudaGraphRunner:
             return_logprob=False,
             positions=positions,
             global_num_tokens_gpu=global_num_tokens,
-            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
-            gathered_buffer=gathered_buffer,
+            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+            global_dp_buffer_len=global_dp_buffer_len,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=(
@@ -258,6 +243,7 @@ class EAGLEDraftCudaGraphRunner:
         def run_once():
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc
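Both runners make the same substitution at capture time: the slices self.gathered_buffer[: num_tokens * self.dp_size] and self.gathered_buffer[:num_tokens] disappear, and only the corresponding integer is placed on the ForwardBatch as global_dp_buffer_len. A small standalone comparison with made-up sizes, just to show that the integer carries the same shape information as the slice did:

    import torch

    hidden_size, max_num_token, dp_size, num_tokens = 16, 8, 4, 3

    # Before: every runner held a full-size workspace and threaded a view of it
    # through the batch.
    gathered_buffer_full = torch.zeros((max_num_token * dp_size, hidden_size))
    gathered_buffer = gathered_buffer_full[: num_tokens * dp_size]  # shape (12, 16)

    # After: only the length is recorded; whoever needs the workspace can build a
    # tensor of exactly that size on demand.
    global_dp_buffer_len = num_tokens * dp_size
    rebuilt = torch.zeros((global_dp_buffer_len, hidden_size))      # shape (12, 16)

    assert rebuilt.shape == gathered_buffer.shape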

View File

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable
 import torch
 
-from sglang.srt.layers.dp_attention import DPPaddingMode
+from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
@@ -117,30 +117,15 @@ class EAGLEDraftExtendCudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token * self.dp_size,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
         else:
             self.global_num_tokens_gpu = None
             self.global_num_tokens_for_logprob_gpu = None
-            self.gathered_buffer = None
 
         if hasattr(
             self.model_runner.model_config.hf_config, "draft_vocab_size"
@@ -222,7 +207,7 @@ class EAGLEDraftExtendCudaGraphRunner:
                     device=self.input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+            global_dp_buffer_len = num_tokens * self.dp_size
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -238,9 +223,9 @@ class EAGLEDraftExtendCudaGraphRunner:
                     device=self.input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_dp_buffer_len = num_tokens
         else:
-            gathered_buffer = None
+            global_dp_buffer_len = None
 
         spec_info = EagleDraftInput(
             hidden_states=hidden_states,
@@ -264,8 +249,8 @@ class EAGLEDraftExtendCudaGraphRunner:
             positions=positions,
             global_num_tokens_gpu=self.global_num_tokens_gpu,
             global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
-            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
-            gathered_buffer=gathered_buffer,
+            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+            global_dp_buffer_len=global_dp_buffer_len,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=CaptureHiddenMode.LAST,
@@ -288,6 +273,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         def run_once():
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc
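The one-line addition inside both run_once bodies is what actually feeds the value back in: the length now lives in shared module-level state rather than on a tensor owned by the runner, so each captured body re-asserts its own value right before the forward pass it drives, next to the existing reset of dp_local_start_pos and dp_local_num_tokens. A toy, CPU-only illustration of that pattern (the names below are stand-ins, not sglang APIs):

    _buffer_len = None  # stand-in for the module-level state behind set_dp_buffer_len

    def set_buffer_len(n):
        global _buffer_len
        _buffer_len = n

    def make_run_once(global_dp_buffer_len):
        def run_once():
            # Each body sets the shared state it depends on, so bodies built for
            # different batch sizes never run against a stale value.
            set_buffer_len(global_dp_buffer_len)
            return _buffer_len
        return run_once

    run_small, run_large = make_run_once(3 * 4), make_run_once(8 * 4)
    assert run_small() == 12 and run_large() == 32 and run_small() == 12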