[DP Attention] Refactor: adding some utility functions (#9136)
This commit is contained in:
@@ -34,9 +34,10 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
DPPaddingMode,
|
||||
DpPaddingMode,
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
set_dp_buffer_len,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.torchao_utils import save_gemlite_cache
|
||||
@@ -349,30 +350,15 @@ class CudaGraphRunner:
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_num_token * self.dp_size,
|
||||
self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
else:
|
||||
assert self.require_attn_tp_gather
|
||||
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(1,), dtype=torch.int32
|
||||
)
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_num_token,
|
||||
self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
else:
|
||||
self.global_num_tokens_gpu = None
|
||||
self.global_num_tokens_for_logprob_gpu = None
|
||||
self.gathered_buffer = None
|
||||
|
||||
self.custom_mask = torch.ones(
|
||||
(
|
||||
@@ -556,7 +542,7 @@ class CudaGraphRunner:
|
||||
device=input_ids.device,
|
||||
)
|
||||
)
|
||||
gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
|
||||
global_dp_buffer_len = num_tokens * self.dp_size
|
||||
elif self.require_attn_tp_gather:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
torch.tensor(
|
||||
@@ -572,9 +558,9 @@ class CudaGraphRunner:
|
||||
device=input_ids.device,
|
||||
)
|
||||
)
|
||||
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||
global_dp_buffer_len = num_tokens
|
||||
else:
|
||||
gathered_buffer = None
|
||||
global_dp_buffer_len = None
|
||||
|
||||
spec_info = self.get_spec_info(num_tokens)
|
||||
if self.capture_hidden_mode != CaptureHiddenMode.FULL:
|
||||
@@ -607,8 +593,8 @@ class CudaGraphRunner:
|
||||
positions=positions,
|
||||
global_num_tokens_gpu=self.global_num_tokens_gpu,
|
||||
global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
|
||||
dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
|
||||
gathered_buffer=gathered_buffer,
|
||||
dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
|
||||
global_dp_buffer_len=global_dp_buffer_len,
|
||||
mrope_positions=mrope_positions,
|
||||
spec_algorithm=self.model_runner.spec_algorithm,
|
||||
spec_info=spec_info,
|
||||
@@ -637,6 +623,7 @@ class CudaGraphRunner:
|
||||
def run_once():
|
||||
# Clean intermediate result cache for DP attention
|
||||
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
||||
set_dp_buffer_len(global_dp_buffer_len, num_tokens)
|
||||
|
||||
kwargs = {}
|
||||
if (
|
||||
|
||||
@@ -40,9 +40,10 @@ import triton.language as tl
|
||||
|
||||
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
DPPaddingMode,
|
||||
DpPaddingMode,
|
||||
get_attention_dp_rank,
|
||||
get_attention_tp_size,
|
||||
set_dp_buffer_len,
|
||||
)
|
||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||
from sglang.srt.utils import (
|
||||
@@ -274,13 +275,13 @@ class ForwardBatch:
|
||||
global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
|
||||
global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
|
||||
# The padding mode for DP attention
|
||||
dp_padding_mode: Optional[DPPaddingMode] = None
|
||||
dp_padding_mode: Optional[DpPaddingMode] = None
|
||||
# for extend, local start pos and num tokens is different in logits processor
|
||||
# this will be computed in get_dp_local_info
|
||||
# this will be recomputed in LogitsMetadata.from_forward_batch
|
||||
dp_local_start_pos: Optional[torch.Tensor] = None # cached info at runtime
|
||||
dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime
|
||||
gathered_buffer: Optional[torch.Tensor] = None
|
||||
global_dp_buffer_len: Optional[int] = None
|
||||
is_extend_in_batch: bool = False
|
||||
can_run_dp_cuda_graph: bool = False
|
||||
global_forward_mode: Optional[ForwardMode] = None
|
||||
@@ -628,7 +629,7 @@ class ForwardBatch:
|
||||
(global_num_tokens[i] - 1) // attn_tp_size + 1
|
||||
) * attn_tp_size
|
||||
|
||||
dp_padding_mode = DPPaddingMode.get_dp_padding_mode(global_num_tokens)
|
||||
dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
|
||||
self.dp_padding_mode = dp_padding_mode
|
||||
|
||||
if dp_padding_mode.is_max_len():
|
||||
@@ -642,17 +643,14 @@ class ForwardBatch:
|
||||
else:
|
||||
buffer_len = sum(global_num_tokens)
|
||||
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(buffer_len, model_runner.model_config.hidden_size),
|
||||
dtype=model_runner.dtype,
|
||||
device=model_runner.device,
|
||||
)
|
||||
|
||||
if len(global_num_tokens) > 1:
|
||||
num_tokens = global_num_tokens[get_attention_dp_rank()]
|
||||
else:
|
||||
num_tokens = global_num_tokens[0]
|
||||
|
||||
self.global_dp_buffer_len = buffer_len
|
||||
set_dp_buffer_len(buffer_len, num_tokens)
|
||||
|
||||
bs = self.batch_size
|
||||
|
||||
if self.forward_mode.is_decode():
|
||||
|
||||
@@ -603,12 +603,8 @@ class ModelRunner:
|
||||
duplicate_tp_group=self.server_args.enable_pdmux,
|
||||
)
|
||||
initialize_dp_attention(
|
||||
enable_dp_attention=self.server_args.enable_dp_attention,
|
||||
tp_rank=self.tp_rank,
|
||||
tp_size=self.tp_size,
|
||||
dp_size=self.server_args.dp_size,
|
||||
moe_dense_tp_size=self.server_args.moe_dense_tp_size,
|
||||
pp_size=self.server_args.pp_size,
|
||||
server_args=self.server_args,
|
||||
model_config=self.model_config,
|
||||
)
|
||||
|
||||
min_per_gpu_memory = get_available_gpu_memory(
|
||||
|
||||
Reference in New Issue
Block a user