Support cuda graph for DP attention (#2061)

Author: Ke Bao
Date: 2024-11-18 08:29:20 +08:00
Committed by: GitHub
Parent: 11f881d173
Commit: 62832bb272
9 changed files with 88 additions and 26 deletions


@@ -455,6 +455,7 @@ class ScheduleBatch:
    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
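
The new can_run_dp_cuda_graph flag marks a batch as eligible for cuda graph replay under DP attention; replay is only safe when every data-parallel rank takes the graph-captured path in the same step. A minimal sketch of how such a flag could be agreed on across ranks, assuming an initialized torch.distributed DP group and a hypothetical per-rank local_can_run input (an illustration, not the actual sglang scheduler code):

import torch
import torch.distributed as dist

def all_ranks_can_run_cuda_graph(local_can_run: bool, dp_group=None) -> bool:
    # And-reduce the per-rank eligibility: the result is 1 only if every DP
    # rank reports 1, i.e. all ranks are in a graph-capturable step.
    flag = torch.tensor([int(local_can_run)], dtype=torch.int32, device="cuda")
    dist.all_reduce(flag, op=dist.ReduceOp.MIN, group=dp_group)
    return bool(flag.item())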
@@ -891,6 +892,13 @@ class ScheduleBatch:
        self.seq_lens = torch.empty(0, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = torch.empty(0, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0

    def prepare_for_decode(self, enable_overlap: bool = False):
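
The zero-length tensors added above let prepare_for_idle build an idle batch: a DP rank with no runnable requests still constructs an empty batch so it can join the synchronized forward (and, with this commit, the shared cuda graph replay) driven by the other ranks. A rough standalone illustration of that idle-batch shape, not the sglang code itself:

import torch

def make_idle_batch(device: str = "cuda") -> dict:
    # Empty int32 tensors, moved to the device asynchronously, mirroring the
    # fields that prepare_for_idle resets in the hunk above.
    def empty() -> torch.Tensor:
        return torch.empty(0, dtype=torch.int32).to(device, non_blocking=True)

    return {
        "input_ids": empty(),
        "seq_lens": empty(),
        "out_cache_loc": empty(),
        "req_pool_indices": empty(),
        "seq_lens_sum": 0,
        "extend_num_tokens": 0,
    }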
@@ -1032,6 +1040,7 @@ class ScheduleBatch:
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            global_num_tokens=self.global_num_tokens,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
@@ -1093,6 +1102,7 @@ class ModelWorkerBatch:
    # For DP attention
    global_num_tokens: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For extend
    extend_num_tokens: Optional[int]
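
ModelWorkerBatch carries the flag from the scheduler down to the model worker, which can then gate graph replay on it. A hypothetical consumer with illustrative names (forward_batch, replay, and eager_forward are assumptions, not the actual sglang runner API):

def forward_batch(batch, cuda_graph_runner, eager_forward):
    # Replay the captured DP-attention graph only when the scheduler marked
    # this batch as eligible on every data-parallel rank; otherwise fall back
    # to the regular eager forward pass.
    if cuda_graph_runner is not None and batch.can_run_dp_cuda_graph:
        return cuda_graph_runner.replay(batch)
    return eager_forward(batch)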