Print progress bar during cuda graph capture (#2502)

This commit is contained in:
Lianmin Zheng
2024-12-17 06:20:44 -08:00
parent 1fc84cf60b
commit 21e9e63ad5
2 changed files with 21 additions and 1 deletions

View File

@@ -20,6 +20,8 @@ from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable
import torch
import tqdm
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import graph_capture
from vllm.model_executor.custom_op import CustomOp
@@ -255,7 +257,12 @@ class CudaGraphRunner:
def capture(self):
with graph_capture() as graph_capture_context:
self.stream = graph_capture_context.stream
for bs in self.capture_bs:
capture_bs = (
tqdm.tqdm(self.capture_bs)
if get_tensor_model_parallel_rank() == 0
else self.capture_bs
)
for bs in capture_bs:
with patch_model(
self.model_runner.model,
bs in self.compile_bs,