Fix bugs in sampler with CUDA graph / torch.compile (#1306)

This commit is contained in:
Liangsheng Yin
2024-09-02 16:18:48 -07:00
committed by GitHub
parent 2561ed012c
commit a5a134f39f
4 changed files with 48 additions and 26 deletions

View File

@@ -523,7 +523,7 @@ class ModelRunner:
if (
self.cuda_graph_runner
and self.cuda_graph_runner.can_run(len(batch.reqs))
and not batch.sampling_info.has_bias()
and batch.sampling_info.can_run_in_cuda_graph()
):
return self.cuda_graph_runner.replay(batch)