Support cuda graph for DP attention (#2061)

Ke Bao
2024-11-18 08:29:20 +08:00
committed by GitHub
parent 11f881d173
commit 62832bb272
9 changed files with 88 additions and 26 deletions

View File

@@ -111,6 +111,8 @@ class CudaGraphRunner:
         self.use_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
+        self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
+        self.tp_size = self.model_runner.tp_size

         # Batch sizes to capture
         if model_runner.server_args.disable_cuda_graph_padding:
@@ -165,6 +167,16 @@ class CudaGraphRunner:
         else:
             self.encoder_lens = None

+        if self.enable_dp_attention:
+            self.global_num_tokens = [0] * self.tp_size
+            self.gathered_buffer = torch.zeros(
+                (
+                    self.max_bs * self.tp_size,
+                    self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+
         # Capture
         try:
             with self.model_capture_mode():
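The gathered_buffer introduced here is what keeps the DP-attention all-gather graph-safe: it is allocated once, sized for the worst case of max_bs tokens from each of the tp_size ranks, so every captured graph can take a fixed slice of it instead of allocating memory inside the graph. A minimal sketch of the sizing arithmetic, using illustrative numbers in place of the real server_args/model_config values:

    import torch

    # Illustrative values only; the real ones come from server_args and model_config.
    max_bs, tp_size, hidden_size = 160, 8, 4096

    # One static buffer sized for the worst case: every rank contributing max_bs tokens.
    gathered_buffer = torch.zeros((max_bs * tp_size, hidden_size), dtype=torch.float16)

    # Each captured batch size bs then reuses a fixed view into the same storage.
    for bs in (1, 2, 4, 8):
        view = gathered_buffer[: bs * tp_size]
        assert view.data_ptr() == gathered_buffer.data_ptr()  # no new allocation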
@@ -190,11 +202,21 @@ class CudaGraphRunner:
             self.model_runner.model.capture_mode = False

     def can_run(self, forward_batch: ForwardBatch):
-        is_bs_supported = (
-            forward_batch.batch_size in self.graphs
-            if self.disable_padding
-            else forward_batch.batch_size <= self.max_bs
-        )
+        if self.enable_dp_attention:
+            min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
+                forward_batch.global_num_tokens
+            )
+            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
+                (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
+                if self.disable_padding
+                else max_num_tokens <= self.max_bs
+            )
+        else:
+            is_bs_supported = (
+                forward_batch.batch_size in self.graphs
+                if self.disable_padding
+                else forward_batch.batch_size <= self.max_bs
+            )

         # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
         # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
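Under DP attention, whether to replay a graph becomes a collective decision: every rank sees the same forward_batch.global_num_tokens list, evaluates the same predicate, and therefore takes the same branch, so no rank replays a graph while another falls back to eager mode. A standalone restatement of that predicate (the helper function is hypothetical, written only to make the two padding modes concrete):

    def can_run_dp(global_num_tokens, graphs, max_bs, disable_padding,
                   can_run_dp_cuda_graph):
        # Hypothetical stand-in for CudaGraphRunner.can_run with DP attention on.
        if not can_run_dp_cuda_graph:
            return False
        lo, hi = min(global_num_tokens), max(global_num_tokens)
        if disable_padding:
            # Without padding, all ranks must already sit at one captured size.
            return lo == hi and hi in graphs
        # With padding, the largest rank only has to fit under max_bs; the
        # smaller ranks get padded up to the same captured size.
        return hi <= max_bs

    graphs = {1, 2, 4, 8, 16}  # illustrative captured batch sizes
    print(can_run_dp([3, 5, 2, 5], graphs, 16, False, True))  # True: pad all to 8
    print(can_run_dp([3, 5, 2, 5], graphs, 16, True, True))   # False: sizes differ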
@@ -239,6 +261,13 @@
         seq_lens_sum = seq_lens.sum().item()
         mrope_positions = self.mrope_positions[:, :bs]

+        if self.enable_dp_attention:
+            self.global_num_tokens[:] = [bs] * self.tp_size
+            gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
+        else:
+            self.global_num_tokens = None
+            gathered_buffer = None
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
@@ -265,6 +294,8 @@
                 top_logprobs_nums=[0] * bs,
                 positions=clamp_position(seq_lens),
                 mrope_positions=mrope_positions,
+                global_num_tokens=self.global_num_tokens,
+                gathered_buffer=gathered_buffer,
             )
             logits_output = forward(input_ids, forward_batch.positions, forward_batch)
             return logits_output.next_token_logits
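Note how global_num_tokens crosses the capture boundary: the ForwardBatch above receives self.global_num_tokens itself, not a copy, and the DP path updates it with a slice assignment (self.global_num_tokens[:] = ...), which mutates the existing list in place. That keeps the reference taken at capture time in sync; rebinding the attribute to a fresh list would silently break the link. A minimal illustration of the difference:

    tp_size = 4
    global_num_tokens = [0] * tp_size
    captured_view = global_num_tokens        # e.g. the reference held since capture

    global_num_tokens[:] = [8] * tp_size     # in-place: the shared list is updated
    print(captured_view)                     # [8, 8, 8, 8]

    global_num_tokens = [16] * tp_size       # rebinding: captured_view goes stale
    print(captured_view)                     # still [8, 8, 8, 8]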
@@ -295,7 +326,12 @@
         raw_bs = forward_batch.batch_size

         # Pad
-        index = bisect.bisect_left(self.capture_bs, raw_bs)
+        if self.enable_dp_attention:
+            index = bisect.bisect_left(
+                self.capture_bs, max(forward_batch.global_num_tokens)
+            )
+        else:
+            index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
             self.seq_lens.fill_(1)
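Padding works the same way in both branches: bisect_left finds the smallest captured batch size that can hold the request, and the batch is padded up to it. The only DP-attention difference is the lookup key, the maximum token count across all ranks, which guarantees that every rank lands on the same captured graph. A small worked example with illustrative capture sizes:

    import bisect

    capture_bs = [1, 2, 4, 8, 16, 24, 32]  # illustrative captured batch sizes

    # Non-DP: pad this rank's own batch size up to the next captured size.
    raw_bs = 5
    print(capture_bs[bisect.bisect_left(capture_bs, raw_bs)])  # 8

    # DP attention: every rank keys on the global maximum, so all pick bs=8.
    global_num_tokens = [3, 5, 2, 5]
    print(capture_bs[bisect.bisect_left(capture_bs, max(global_num_tokens))])  # 8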
@@ -310,6 +346,8 @@
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+        if self.enable_dp_attention:
+            self.global_num_tokens[:] = [bs] * self.tp_size

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(

View File

@@ -138,6 +138,7 @@ class ForwardBatch:
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     gathered_buffer: Optional[torch.Tensor] = None
+    can_run_dp_cuda_graph: bool = False

     def compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
@@ -221,6 +222,7 @@ class ForwardBatch:
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             global_num_tokens=batch.global_num_tokens,
+            can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
         )
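The new flag simply rides along from the scheduler's ModelWorkerBatch into the ForwardBatch, where CudaGraphRunner.can_run reads it (first file above). A stripped-down sketch of that hand-off, with hypothetical minimal dataclasses standing in for the real ModelWorkerBatch and ForwardBatch:

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class WorkerBatch:  # hypothetical stand-in for ModelWorkerBatch
        global_num_tokens: Optional[List[int]] = None
        can_run_dp_cuda_graph: bool = False

    @dataclass
    class FwdBatch:  # hypothetical stand-in for ForwardBatch
        global_num_tokens: Optional[List[int]] = None
        can_run_dp_cuda_graph: bool = False

        @classmethod
        def init_new(cls, batch: WorkerBatch) -> "FwdBatch":
            # The flag is copied through unchanged, as in the hunk above.
            return cls(
                global_num_tokens=batch.global_num_tokens,
                can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
            )

    fb = FwdBatch.init_new(WorkerBatch([4, 4], True))
    print(fb.can_run_dp_cuda_graph)  # True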

View File

@@ -592,6 +592,9 @@ class ModelRunner:
         )

     def forward_idle(self, forward_batch: ForwardBatch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
+            return self.cuda_graph_runner.replay(forward_batch)
+
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
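forward_idle exists because DP attention keeps all ranks in lock-step: a rank that happens to have zero tokens to decode must still join the collective communication, so it runs an idle forward rather than skipping the step. With this change the idle forward can replay the captured graph whenever can_run agrees, falling back to eager execution otherwise; since can_run keys on the shared global_num_tokens, every rank takes the same branch. A simplified sketch of the dispatch (the class below is a stand-in, not the real ModelRunner):

    class ModelRunnerSketch:
        """Stand-in for ModelRunner, showing only the idle-forward dispatch."""

        def __init__(self, model, cuda_graph_runner=None):
            self.model = model
            self.cuda_graph_runner = cuda_graph_runner

        def forward_idle(self, forward_batch):
            # Prefer graph replay; can_run checks can_run_dp_cuda_graph and the
            # global token counts, so all DP ranks decide identically.
            if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
                return self.cuda_graph_runner.replay(forward_batch)
            # Eager fallback, identical to the pre-change behavior.
            return self.model.forward(
                forward_batch.input_ids, forward_batch.positions, forward_batch
            )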