Improve DP attention (#4390)
Co-authored-by: dhou-xai <dhou@x.ai> Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
@@ -33,7 +33,7 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.utils import is_hip
|
||||
from sglang.srt.utils import get_available_gpu_memory, is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
@@ -174,6 +174,7 @@ class CudaGraphRunner:
|
||||
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
|
||||
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
|
||||
self.speculative_algorithm = model_runner.server_args.speculative_algorithm
|
||||
self.tp_size = model_runner.server_args.tp_size
|
||||
self.dp_size = model_runner.server_args.dp_size
|
||||
|
||||
@@ -236,7 +237,7 @@ class CudaGraphRunner:
|
||||
if self.enable_dp_attention:
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_bs * self.dp_size,
|
||||
self.max_bs * self.dp_size * self.num_tokens_per_bs,
|
||||
self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
@@ -276,13 +277,12 @@ class CudaGraphRunner:
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
if self.enable_dp_attention:
|
||||
min_num_tokens, max_num_tokens = min(
|
||||
forward_batch.global_num_tokens_cpu
|
||||
), max(forward_batch.global_num_tokens_cpu)
|
||||
total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
|
||||
|
||||
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
|
||||
(min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
|
||||
total_global_tokens in self.graphs
|
||||
if self.disable_padding
|
||||
else max_num_tokens <= self.max_bs
|
||||
else total_global_tokens <= self.max_bs
|
||||
)
|
||||
else:
|
||||
is_bs_supported = (
|
||||
@@ -304,6 +304,9 @@ class CudaGraphRunner:
|
||||
def capture(self):
|
||||
with graph_capture() as graph_capture_context:
|
||||
self.stream = graph_capture_context.stream
|
||||
avail_mem = get_available_gpu_memory(
|
||||
self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
|
||||
)
|
||||
# Reverse the order to enable better memory sharing across cuda graphs.
|
||||
capture_range = (
|
||||
tqdm.tqdm(list(reversed(self.capture_bs)))
|
||||
@@ -311,6 +314,16 @@ class CudaGraphRunner:
|
||||
else reversed(self.capture_bs)
|
||||
)
|
||||
for bs in capture_range:
|
||||
if get_tensor_model_parallel_rank() == 0:
|
||||
avail_mem = get_available_gpu_memory(
|
||||
self.model_runner.device,
|
||||
self.model_runner.gpu_id,
|
||||
empty_cache=False,
|
||||
)
|
||||
capture_range.set_description(
|
||||
f"Capturing batches ({avail_mem=:.2f} GB)"
|
||||
)
|
||||
|
||||
with patch_model(
|
||||
self.model_runner.model,
|
||||
bs in self.compile_bs,
|
||||
@@ -345,8 +358,18 @@ class CudaGraphRunner:
|
||||
mrope_positions = self.mrope_positions[:, :bs]
|
||||
|
||||
if self.enable_dp_attention:
|
||||
global_num_tokens = [bs] * self.tp_size
|
||||
gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
torch.tensor(
|
||||
[
|
||||
num_tokens // self.dp_size + (i < bs % self.dp_size)
|
||||
for i in range(self.dp_size)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device=input_ids.device,
|
||||
)
|
||||
)
|
||||
global_num_tokens = self.global_num_tokens_gpu
|
||||
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||
else:
|
||||
global_num_tokens = None
|
||||
gathered_buffer = None
|
||||
@@ -371,7 +394,7 @@ class CudaGraphRunner:
|
||||
encoder_lens=encoder_lens,
|
||||
return_logprob=False,
|
||||
positions=positions,
|
||||
global_num_tokens_cpu=global_num_tokens,
|
||||
global_num_tokens_gpu=global_num_tokens,
|
||||
gathered_buffer=gathered_buffer,
|
||||
mrope_positions=mrope_positions,
|
||||
spec_algorithm=self.model_runner.spec_algorithm,
|
||||
@@ -392,6 +415,9 @@ class CudaGraphRunner:
|
||||
|
||||
# Run and capture
|
||||
def run_once():
|
||||
# Clean intermediate result cache for DP attention
|
||||
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
||||
|
||||
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
|
||||
return logits_output.next_token_logits, logits_output.hidden_states
|
||||
|
||||
@@ -426,7 +452,7 @@ class CudaGraphRunner:
|
||||
self.capture_hidden_mode = hidden_mode_from_spec_info
|
||||
self.capture()
|
||||
|
||||
def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
|
||||
def replay_prepare(self, forward_batch: ForwardBatch):
|
||||
self.recapture_if_needed(forward_batch)
|
||||
|
||||
raw_bs = forward_batch.batch_size
|
||||
@@ -435,7 +461,7 @@ class CudaGraphRunner:
|
||||
# Pad
|
||||
if self.enable_dp_attention:
|
||||
index = bisect.bisect_left(
|
||||
self.capture_bs, max(forward_batch.global_num_tokens_cpu)
|
||||
self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
|
||||
)
|
||||
else:
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
@@ -459,6 +485,8 @@ class CudaGraphRunner:
|
||||
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
||||
if forward_batch.mrope_positions is not None:
|
||||
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
|
||||
if self.enable_dp_attention:
|
||||
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
|
||||
|
||||
if hasattr(forward_batch.spec_info, "hidden_states"):
|
||||
self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
|
||||
@@ -475,14 +503,29 @@ class CudaGraphRunner:
|
||||
seq_lens_cpu=self.seq_lens_cpu,
|
||||
)
|
||||
|
||||
# Store fields
|
||||
self.raw_bs = raw_bs
|
||||
self.raw_num_token = raw_num_token
|
||||
self.bs = bs
|
||||
|
||||
def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
|
||||
if not skip_attn_backend_init:
|
||||
self.replay_prepare(forward_batch)
|
||||
else:
|
||||
# In speculative decoding, these two fields are still needed.
|
||||
self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
|
||||
self.positions[: self.raw_num_token].copy_(forward_batch.positions)
|
||||
|
||||
# Replay
|
||||
self.graphs[bs].replay()
|
||||
next_token_logits, hidden_states = self.output_buffers[bs]
|
||||
self.graphs[self.bs].replay()
|
||||
next_token_logits, hidden_states = self.output_buffers[self.bs]
|
||||
|
||||
logits_output = LogitsProcessorOutput(
|
||||
next_token_logits=next_token_logits[:raw_num_token],
|
||||
next_token_logits=next_token_logits[: self.raw_num_token],
|
||||
hidden_states=(
|
||||
hidden_states[:raw_num_token] if hidden_states is not None else None
|
||||
hidden_states[: self.raw_num_token]
|
||||
if hidden_states is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
return logits_output
|
||||
|
||||
@@ -38,7 +38,7 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||
from sglang.srt.utils import get_compiler_backend, next_power_of_2
|
||||
from sglang.srt.utils import get_compiler_backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
@@ -263,15 +263,24 @@ class ForwardBatch:
|
||||
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
|
||||
)
|
||||
|
||||
# For DP attention
|
||||
if batch.global_num_tokens is not None:
|
||||
ret.global_num_tokens_cpu = batch.global_num_tokens
|
||||
max_len = max(ret.global_num_tokens_cpu)
|
||||
ret.global_num_tokens_gpu = torch.tensor(
|
||||
batch.global_num_tokens, dtype=torch.int64
|
||||
).to(device, non_blocking=True)
|
||||
|
||||
ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
|
||||
ret.global_num_tokens_for_logprob_gpu = torch.tensor(
|
||||
batch.global_num_tokens_for_logprob, dtype=torch.int64
|
||||
).to(device, non_blocking=True)
|
||||
|
||||
sum_len = sum(batch.global_num_tokens)
|
||||
ret.gathered_buffer = torch.zeros(
|
||||
(max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
|
||||
(sum_len, model_runner.model_config.hidden_size),
|
||||
dtype=model_runner.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if ret.forward_mode.is_idle():
|
||||
ret.positions = torch.empty((0,), device=device)
|
||||
return ret
|
||||
|
||||
Reference in New Issue
Block a user