Support cuda graph for DP attention (#2061)
This commit is contained in:
@@ -592,6 +592,9 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
def forward_idle(self, forward_batch: ForwardBatch):
|
||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
|
||||
return self.cuda_graph_runner.replay(forward_batch)
|
||||
|
||||
return self.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user