Print progress bar during cuda graph capture (#2502)
This commit is contained in:
13
docs/references/torch_compile_cache.md
Normal file
13
docs/references/torch_compile_cache.md
Normal file
@@ -0,0 +1,13 @@
|
||||
# Enabling cache for torch.compile
|
||||
|
||||
SGLang uses `max-autotune-no-cudagraphs` mode of torch.compile. The auto-tuning can be slow.
|
||||
If you want to deploy a model on many different machines, you can ship the torch.compile cache to these machines and skip the compilation steps.
|
||||
|
||||
This is based on https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html
|
||||
|
||||
|
||||
1. Generate the cache by setting TORCHINDUCTOR_CACHE_DIR and running the model once.
|
||||
```
|
||||
TORCHINDUCTOR_CACHE_DIR=/root/inductor_root_cache python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile
|
||||
```
|
||||
2. Copy the cache folder to other machines and launch the server with `TORCHINDUCTOR_CACHE_DIR`.
|
||||
@@ -20,6 +20,8 @@ from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.distributed.parallel_state import graph_capture
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
|
||||
@@ -255,7 +257,12 @@ class CudaGraphRunner:
|
||||
def capture(self):
|
||||
with graph_capture() as graph_capture_context:
|
||||
self.stream = graph_capture_context.stream
|
||||
for bs in self.capture_bs:
|
||||
capture_bs = (
|
||||
tqdm.tqdm(self.capture_bs)
|
||||
if get_tensor_model_parallel_rank() == 0
|
||||
else self.capture_bs
|
||||
)
|
||||
for bs in capture_bs:
|
||||
with patch_model(
|
||||
self.model_runner.model,
|
||||
bs in self.compile_bs,
|
||||
|
||||
Reference in New Issue
Block a user