From 21e9e63ad56f8bd25663fa6907ed92f47a2b2724 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Tue, 17 Dec 2024 06:20:44 -0800
Subject: [PATCH] Print progress bar during cuda graph capture (#2502)

---
 docs/references/torch_compile_cache.md             | 13 +++++++++++++
 .../sglang/srt/model_executor/cuda_graph_runner.py |  9 ++++++++-
 2 files changed, 21 insertions(+), 1 deletion(-)
 create mode 100644 docs/references/torch_compile_cache.md

diff --git a/docs/references/torch_compile_cache.md b/docs/references/torch_compile_cache.md
new file mode 100644
index 000000000..f2bb257f4
--- /dev/null
+++ b/docs/references/torch_compile_cache.md
@@ -0,0 +1,13 @@
+# Enabling cache for torch.compile
+
+SGLang uses `max-autotune-no-cudagraphs` mode of torch.compile. The auto-tuning can be slow.
+If you want to deploy a model on many different machines, you can ship the torch.compile cache to these machines and skip the compilation steps.
+
+This is based on https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html
+
+
+1. Generate the cache by setting TORCHINDUCTOR_CACHE_DIR and running the model once.
+```
+TORCHINDUCTOR_CACHE_DIR=/root/inductor_root_cache python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile
+```
+2. Copy the cache folder to other machines and launch the server with `TORCHINDUCTOR_CACHE_DIR`.
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index a113fb9c0..1efd65e57 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -20,6 +20,8 @@
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable
 
 import torch
+import tqdm
+from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
@@ -255,7 +257,12 @@ class CudaGraphRunner:
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
-            for bs in self.capture_bs:
+            capture_bs = (
+                tqdm.tqdm(self.capture_bs)
+                if get_tensor_model_parallel_rank() == 0
+                else self.capture_bs
+            )
+            for bs in capture_bs:
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,