Support cuda graph for DP attention (#2061)
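This commit threads data-parallel (DP) attention metadata through the batch structures: `ScheduleBatch` gains a `can_run_dp_cuda_graph` flag next to the existing `global_num_tokens` field, both are forwarded when the scheduler builds a `ModelWorkerBatch`, and what appears to be an idle-batch initialization path is filled with zero-length tensors so a DP rank with no requests can still step through the forward pass.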
```
@@ -455,6 +455,7 @@ class ScheduleBatch:

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
```
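As a hedged illustration only (none of this is code from the commit; the helper name and the exact policy are assumptions): with DP attention, a CUDA graph can only be replayed when every rank runs a batch shape the graph was captured for, so a scheduler could derive the flag from gathered per-rank state along these lines:

```python
from typing import List

def can_run_dp_cuda_graph(
    forward_modes: List[str],
    global_num_tokens: List[int],
    max_captured_bs: int,
) -> bool:
    # Hypothetical policy: every DP rank must be in a decode (or idle)
    # step, and every rank's token count must fit a captured graph.
    if not all(mode in ("decode", "idle") for mode in forward_modes):
        return False
    return all(n <= max_captured_bs for n in global_num_tokens)
```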
```
@@ -891,6 +892,13 @@ class ScheduleBatch:
        self.seq_lens = torch.empty(0, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = torch.empty(0, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0

    def prepare_for_decode(self, enable_overlap: bool = False):
```
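The zero-length tensors above (rather than `None`) let a rank with no work flow through the same tensor code paths as a busy rank. A minimal, self-contained demonstration of that property:

```python
import torch

idle_seq_lens = torch.empty(0, dtype=torch.int32)   # as initialized above
real_seq_lens = torch.tensor([5, 9], dtype=torch.int32)

# Concatenating an idle batch is a no-op, and its totals are zero,
# matching seq_lens_sum = 0 and extend_num_tokens = 0 in the diff.
assert torch.cat([idle_seq_lens, real_seq_lens]).tolist() == [5, 9]
assert int(idle_seq_lens.sum()) == 0
```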
```
@@ -1032,6 +1040,7 @@ class ScheduleBatch:
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            global_num_tokens=self.global_num_tokens,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
```
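These keyword arguments appear to mirror fields on `ModelWorkerBatch` one-to-one (next hunk), so the worker side receives the same `global_num_tokens` and `can_run_dp_cuda_graph` values the scheduler computed.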
```
@@ -1093,6 +1102,7 @@ class ModelWorkerBatch:

    # For DP attention
    global_num_tokens: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For extend
    extend_num_tokens: Optional[int]
```
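To show where the flag plausibly pays off downstream, here is a sketch under stated assumptions (`MiniWorkerBatch` and `run_batch` are invented stand-ins, not the repository's runner API):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class MiniWorkerBatch:  # stand-in for the DP fields of ModelWorkerBatch
    global_num_tokens: Optional[List[int]]
    can_run_dp_cuda_graph: bool

def run_batch(batch: MiniWorkerBatch) -> str:
    # Hypothetical dispatch: take the captured-graph fast path only when
    # the scheduler marked the batch safe for DP CUDA graph replay.
    if batch.can_run_dp_cuda_graph:
        return "replay captured CUDA graph"
    return "eager forward"

print(run_batch(MiniWorkerBatch(global_num_tokens=[4, 4], can_run_dp_cuda_graph=True)))
```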