Print progress bar during CUDA graph capture (#2502)
This commit is contained in:
@@ -20,6 +20,8 @@ from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable
 
 import torch
+import tqdm
+from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
@@ -255,7 +257,12 @@ class CudaGraphRunner:
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
-            for bs in self.capture_bs:
+            capture_bs = (
+                tqdm.tqdm(self.capture_bs)
+                if get_tensor_model_parallel_rank() == 0
+                else self.capture_bs
+            )
+            for bs in capture_bs:
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
Reference in New Issue
Block a user