Improve DP attention (#4390)

Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
Lianmin Zheng
2025-03-13 08:23:56 -07:00
committed by GitHub
parent f141298a3c
commit 8e66fbecee
9 changed files with 345 additions and 226 deletions

View File

@@ -33,7 +33,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
)
from sglang.srt.utils import is_hip
from sglang.srt.utils import get_available_gpu_memory, is_hip
_is_hip = is_hip()
@@ -174,6 +174,7 @@ class CudaGraphRunner:
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
self.speculative_algorithm = model_runner.server_args.speculative_algorithm
self.tp_size = model_runner.server_args.tp_size
self.dp_size = model_runner.server_args.dp_size
@@ -236,7 +237,7 @@ class CudaGraphRunner:
if self.enable_dp_attention:
self.gathered_buffer = torch.zeros(
(
self.max_bs * self.dp_size,
self.max_bs * self.dp_size * self.num_tokens_per_bs,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
@@ -276,13 +277,12 @@ class CudaGraphRunner:
def can_run(self, forward_batch: ForwardBatch):
if self.enable_dp_attention:
min_num_tokens, max_num_tokens = min(
forward_batch.global_num_tokens_cpu
), max(forward_batch.global_num_tokens_cpu)
total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
(min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
total_global_tokens in self.graphs
if self.disable_padding
else max_num_tokens <= self.max_bs
else total_global_tokens <= self.max_bs
)
else:
is_bs_supported = (
@@ -304,6 +304,9 @@ class CudaGraphRunner:
def capture(self):
with graph_capture() as graph_capture_context:
self.stream = graph_capture_context.stream
avail_mem = get_available_gpu_memory(
self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
)
# Reverse the order to enable better memory sharing across cuda graphs.
capture_range = (
tqdm.tqdm(list(reversed(self.capture_bs)))
@@ -311,6 +314,16 @@ class CudaGraphRunner:
else reversed(self.capture_bs)
)
for bs in capture_range:
if get_tensor_model_parallel_rank() == 0:
avail_mem = get_available_gpu_memory(
self.model_runner.device,
self.model_runner.gpu_id,
empty_cache=False,
)
capture_range.set_description(
f"Capturing batches ({avail_mem=:.2f} GB)"
)
with patch_model(
self.model_runner.model,
bs in self.compile_bs,
@@ -345,8 +358,18 @@ class CudaGraphRunner:
mrope_positions = self.mrope_positions[:, :bs]
if self.enable_dp_attention:
global_num_tokens = [bs] * self.tp_size
gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
self.global_num_tokens_gpu.copy_(
torch.tensor(
[
num_tokens // self.dp_size + (i < bs % self.dp_size)
for i in range(self.dp_size)
],
dtype=torch.int32,
device=input_ids.device,
)
)
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
else:
global_num_tokens = None
gathered_buffer = None
@@ -371,7 +394,7 @@ class CudaGraphRunner:
encoder_lens=encoder_lens,
return_logprob=False,
positions=positions,
global_num_tokens_cpu=global_num_tokens,
global_num_tokens_gpu=global_num_tokens,
gathered_buffer=gathered_buffer,
mrope_positions=mrope_positions,
spec_algorithm=self.model_runner.spec_algorithm,
@@ -392,6 +415,9 @@ class CudaGraphRunner:
# Run and capture
def run_once():
# Clean intermediate result cache for DP attention
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
return logits_output.next_token_logits, logits_output.hidden_states
@@ -426,7 +452,7 @@ class CudaGraphRunner:
self.capture_hidden_mode = hidden_mode_from_spec_info
self.capture()
def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
def replay_prepare(self, forward_batch: ForwardBatch):
self.recapture_if_needed(forward_batch)
raw_bs = forward_batch.batch_size
@@ -435,7 +461,7 @@ class CudaGraphRunner:
# Pad
if self.enable_dp_attention:
index = bisect.bisect_left(
self.capture_bs, max(forward_batch.global_num_tokens_cpu)
self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
)
else:
index = bisect.bisect_left(self.capture_bs, raw_bs)
@@ -459,6 +485,8 @@ class CudaGraphRunner:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
if forward_batch.mrope_positions is not None:
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
if self.enable_dp_attention:
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
if hasattr(forward_batch.spec_info, "hidden_states"):
self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
@@ -475,14 +503,29 @@ class CudaGraphRunner:
seq_lens_cpu=self.seq_lens_cpu,
)
# Store fields
self.raw_bs = raw_bs
self.raw_num_token = raw_num_token
self.bs = bs
def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
if not skip_attn_backend_init:
self.replay_prepare(forward_batch)
else:
# In speculative decoding, these two fields are still needed.
self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
self.positions[: self.raw_num_token].copy_(forward_batch.positions)
# Replay
self.graphs[bs].replay()
next_token_logits, hidden_states = self.output_buffers[bs]
self.graphs[self.bs].replay()
next_token_logits, hidden_states = self.output_buffers[self.bs]
logits_output = LogitsProcessorOutput(
next_token_logits=next_token_logits[:raw_num_token],
next_token_logits=next_token_logits[: self.raw_num_token],
hidden_states=(
hidden_states[:raw_num_token] if hidden_states is not None else None
hidden_states[: self.raw_num_token]
if hidden_states is not None
else None
),
)
return logits_output

View File

@@ -38,7 +38,7 @@ import triton
import triton.language as tl
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import get_compiler_backend, next_power_of_2
from sglang.srt.utils import get_compiler_backend
if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -263,15 +263,24 @@ class ForwardBatch:
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
)
# For DP attention
if batch.global_num_tokens is not None:
ret.global_num_tokens_cpu = batch.global_num_tokens
max_len = max(ret.global_num_tokens_cpu)
ret.global_num_tokens_gpu = torch.tensor(
batch.global_num_tokens, dtype=torch.int64
).to(device, non_blocking=True)
ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
ret.global_num_tokens_for_logprob_gpu = torch.tensor(
batch.global_num_tokens_for_logprob, dtype=torch.int64
).to(device, non_blocking=True)
sum_len = sum(batch.global_num_tokens)
ret.gathered_buffer = torch.zeros(
(max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
(sum_len, model_runner.model_config.hidden_size),
dtype=model_runner.dtype,
device=device,
)
if ret.forward_mode.is_idle():
ret.positions = torch.empty((0,), device=device)
return ret