[DP Attention] Refactor: adding some utility functions (#9136)
This commit is contained in:
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.dp_attention import DPPaddingMode
|
||||
from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
|
||||
from sglang.srt.model_executor.cuda_graph_runner import (
|
||||
CUDA_GRAPH_CAPTURE_FAILED_MSG,
|
||||
CudaGraphRunner,
|
||||
@@ -105,30 +105,15 @@ class EAGLEDraftCudaGraphRunner:
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_num_token * self.dp_size,
|
||||
self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
else:
|
||||
assert self.require_attn_tp_gather
|
||||
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(1,), dtype=torch.int32
|
||||
)
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_num_token,
|
||||
self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
else:
|
||||
self.global_num_tokens_gpu = None
|
||||
self.global_num_tokens_for_logprob_gpu = None
|
||||
self.gathered_buffer = None
|
||||
|
||||
# Capture
|
||||
try:
|
||||
@@ -193,7 +178,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
)
|
||||
)
|
||||
global_num_tokens = self.global_num_tokens_gpu
|
||||
gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
|
||||
global_dp_buffer_len = num_tokens * self.dp_size
|
||||
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
|
||||
elif self.require_attn_tp_gather:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
@@ -211,11 +196,11 @@ class EAGLEDraftCudaGraphRunner:
|
||||
)
|
||||
)
|
||||
global_num_tokens = self.global_num_tokens_gpu
|
||||
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||
global_dp_buffer_len = num_tokens
|
||||
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
|
||||
else:
|
||||
global_num_tokens = None
|
||||
gathered_buffer = None
|
||||
global_dp_buffer_len = None
|
||||
global_num_tokens_for_logprob = None
|
||||
|
||||
spec_info = EagleDraftInput(
|
||||
@@ -239,8 +224,8 @@ class EAGLEDraftCudaGraphRunner:
|
||||
return_logprob=False,
|
||||
positions=positions,
|
||||
global_num_tokens_gpu=global_num_tokens,
|
||||
dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
|
||||
gathered_buffer=gathered_buffer,
|
||||
dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
|
||||
global_dp_buffer_len=global_dp_buffer_len,
|
||||
spec_algorithm=self.model_runner.spec_algorithm,
|
||||
spec_info=spec_info,
|
||||
capture_hidden_mode=(
|
||||
@@ -258,6 +243,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
def run_once():
|
||||
# Clean intermediate result cache for DP attention
|
||||
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
||||
set_dp_buffer_len(global_dp_buffer_len, num_tokens)
|
||||
|
||||
# Backup two fields, which will be modified in-place in `draft_forward`.
|
||||
output_cache_loc_backup = forward_batch.out_cache_loc
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.dp_attention import DPPaddingMode
|
||||
from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
|
||||
from sglang.srt.model_executor.cuda_graph_runner import (
|
||||
CUDA_GRAPH_CAPTURE_FAILED_MSG,
|
||||
CudaGraphRunner,
|
||||
@@ -117,30 +117,15 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_num_token * self.dp_size,
|
||||
self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
else:
|
||||
assert self.require_attn_tp_gather
|
||||
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(1,), dtype=torch.int32
|
||||
)
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_num_token,
|
||||
self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
else:
|
||||
self.global_num_tokens_gpu = None
|
||||
self.global_num_tokens_for_logprob_gpu = None
|
||||
self.gathered_buffer = None
|
||||
|
||||
if hasattr(
|
||||
self.model_runner.model_config.hf_config, "draft_vocab_size"
|
||||
@@ -222,7 +207,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
)
|
||||
gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
|
||||
global_dp_buffer_len = num_tokens * self.dp_size
|
||||
elif self.require_attn_tp_gather:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
torch.tensor(
|
||||
@@ -238,9 +223,9 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
)
|
||||
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||
global_dp_buffer_len = num_tokens
|
||||
else:
|
||||
gathered_buffer = None
|
||||
global_dp_buffer_len = None
|
||||
|
||||
spec_info = EagleDraftInput(
|
||||
hidden_states=hidden_states,
|
||||
@@ -264,8 +249,8 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
positions=positions,
|
||||
global_num_tokens_gpu=self.global_num_tokens_gpu,
|
||||
global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
|
||||
dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
|
||||
gathered_buffer=gathered_buffer,
|
||||
dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
|
||||
global_dp_buffer_len=global_dp_buffer_len,
|
||||
spec_algorithm=self.model_runner.spec_algorithm,
|
||||
spec_info=spec_info,
|
||||
capture_hidden_mode=CaptureHiddenMode.LAST,
|
||||
@@ -288,6 +273,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
def run_once():
|
||||
# Clean intermediate result cache for DP attention
|
||||
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
||||
set_dp_buffer_len(global_dp_buffer_len, num_tokens)
|
||||
|
||||
# Backup two fields, which will be modified in-place in `draft_forward`.
|
||||
output_cache_loc_backup = forward_batch.out_cache_loc
|
||||
|
||||
Reference in New Issue
Block a user