Support cuda graph for DP attention (#2061)
This commit is contained in:
@@ -111,6 +111,8 @@ class CudaGraphRunner:
|
||||
self.use_torch_compile = model_runner.server_args.enable_torch_compile
|
||||
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||
self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
|
||||
self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
|
||||
self.tp_size = self.model_runner.tp_size
|
||||
|
||||
# Batch sizes to capture
|
||||
if model_runner.server_args.disable_cuda_graph_padding:
|
||||
@@ -165,6 +167,16 @@ class CudaGraphRunner:
|
||||
else:
|
||||
self.encoder_lens = None
|
||||
|
||||
if self.enable_dp_attention:
|
||||
self.global_num_tokens = [0] * self.tp_size
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_bs * self.tp_size,
|
||||
self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
|
||||
# Capture
|
||||
try:
|
||||
with self.model_capture_mode():
|
||||
@@ -190,11 +202,21 @@ class CudaGraphRunner:
|
||||
self.model_runner.model.capture_mode = False
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
is_bs_supported = (
|
||||
forward_batch.batch_size in self.graphs
|
||||
if self.disable_padding
|
||||
else forward_batch.batch_size <= self.max_bs
|
||||
)
|
||||
if self.enable_dp_attention:
|
||||
min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
|
||||
forward_batch.global_num_tokens
|
||||
)
|
||||
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
|
||||
(min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
|
||||
if self.disable_padding
|
||||
else max_num_tokens <= self.max_bs
|
||||
)
|
||||
else:
|
||||
is_bs_supported = (
|
||||
forward_batch.batch_size in self.graphs
|
||||
if self.disable_padding
|
||||
else forward_batch.batch_size <= self.max_bs
|
||||
)
|
||||
|
||||
# NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
|
||||
# If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
|
||||
@@ -239,6 +261,13 @@ class CudaGraphRunner:
|
||||
seq_lens_sum = seq_lens.sum().item()
|
||||
mrope_positions = self.mrope_positions[:, :bs]
|
||||
|
||||
if self.enable_dp_attention:
|
||||
self.global_num_tokens[:] = [bs] * self.tp_size
|
||||
gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
|
||||
else:
|
||||
self.global_num_tokens = None
|
||||
gathered_buffer = None
|
||||
|
||||
# Attention backend
|
||||
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
|
||||
bs,
|
||||
@@ -265,6 +294,8 @@ class CudaGraphRunner:
|
||||
top_logprobs_nums=[0] * bs,
|
||||
positions=clamp_position(seq_lens),
|
||||
mrope_positions=mrope_positions,
|
||||
global_num_tokens=self.global_num_tokens,
|
||||
gathered_buffer=gathered_buffer,
|
||||
)
|
||||
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
|
||||
return logits_output.next_token_logits
|
||||
@@ -295,7 +326,12 @@ class CudaGraphRunner:
|
||||
raw_bs = forward_batch.batch_size
|
||||
|
||||
# Pad
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
if self.enable_dp_attention:
|
||||
index = bisect.bisect_left(
|
||||
self.capture_bs, max(forward_batch.global_num_tokens)
|
||||
)
|
||||
else:
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
bs = self.capture_bs[index]
|
||||
if bs != raw_bs:
|
||||
self.seq_lens.fill_(1)
|
||||
@@ -310,6 +346,8 @@ class CudaGraphRunner:
|
||||
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
||||
if forward_batch.mrope_positions is not None:
|
||||
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
|
||||
if self.enable_dp_attention:
|
||||
self.global_num_tokens[:] = [bs] * self.tp_size
|
||||
|
||||
# Attention backend
|
||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||
|
||||
@@ -138,6 +138,7 @@ class ForwardBatch:
|
||||
# For DP attention
|
||||
global_num_tokens: Optional[List[int]] = None
|
||||
gathered_buffer: Optional[torch.Tensor] = None
|
||||
can_run_dp_cuda_graph: bool = False
|
||||
|
||||
def compute_mrope_positions(
|
||||
self, model_runner: ModelRunner, batch: ModelWorkerBatch
|
||||
@@ -221,6 +222,7 @@ class ForwardBatch:
|
||||
return_logprob=batch.return_logprob,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
global_num_tokens=batch.global_num_tokens,
|
||||
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
|
||||
lora_paths=batch.lora_paths,
|
||||
sampling_info=batch.sampling_info,
|
||||
)
|
||||
|
||||
@@ -592,6 +592,9 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
def forward_idle(self, forward_batch: ForwardBatch):
|
||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
|
||||
return self.cuda_graph_runner.replay(forward_batch)
|
||||
|
||||
return self.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user