Support cuda graph for DP attention (#2061)

This commit is contained in:
Ke Bao
2024-11-18 08:29:20 +08:00
committed by GitHub
parent 11f881d173
commit 62832bb272
9 changed files with 88 additions and 26 deletions

View File

@@ -592,6 +592,9 @@ class ModelRunner:
)
def forward_idle(self, forward_batch: ForwardBatch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
return self.cuda_graph_runner.replay(forward_batch)
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)