Fix bugs in sampler with CUDA graph / torch.compile (#1306)
This commit is contained in:
@@ -523,7 +523,7 @@ class ModelRunner:
|
||||
if (
|
||||
self.cuda_graph_runner
|
||||
and self.cuda_graph_runner.can_run(len(batch.reqs))
|
||||
and not batch.sampling_info.has_bias()
|
||||
and batch.sampling_info.can_run_in_cuda_graph()
|
||||
):
|
||||
return self.cuda_graph_runner.replay(batch)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user