Print progress bar during cuda graph capture (#2502)
This commit is contained in:
13
docs/references/torch_compile_cache.md
Normal file
13
docs/references/torch_compile_cache.md
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Enabling cache for torch.compile
|
||||||
|
|
||||||
|
SGLang uses `max-autotune-no-cudagraphs` mode of torch.compile. The auto-tuning can be slow.
|
||||||
|
If you want to deploy a model on many different machines, you can ship the torch.compile cache to these machines and skip the compilation steps.
|
||||||
|
|
||||||
|
This is based on https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html
|
||||||
|
|
||||||
|
|
||||||
|
1. Generate the cache by setting TORCHINDUCTOR_CACHE_DIR and running the model once.
|
||||||
|
```
|
||||||
|
TORCHINDUCTOR_CACHE_DIR=/root/inductor_root_cache python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile
|
||||||
|
```
|
||||||
|
2. Copy the cache folder to other machines and launch the server with `TORCHINDUCTOR_CACHE_DIR`.
|
||||||
@@ -20,6 +20,8 @@ from contextlib import contextmanager
|
|||||||
from typing import TYPE_CHECKING, Callable
|
from typing import TYPE_CHECKING, Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import tqdm
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_rank
|
||||||
from vllm.distributed.parallel_state import graph_capture
|
from vllm.distributed.parallel_state import graph_capture
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
|
||||||
@@ -255,7 +257,12 @@ class CudaGraphRunner:
|
|||||||
def capture(self):
|
def capture(self):
|
||||||
with graph_capture() as graph_capture_context:
|
with graph_capture() as graph_capture_context:
|
||||||
self.stream = graph_capture_context.stream
|
self.stream = graph_capture_context.stream
|
||||||
for bs in self.capture_bs:
|
capture_bs = (
|
||||||
|
tqdm.tqdm(self.capture_bs)
|
||||||
|
if get_tensor_model_parallel_rank() == 0
|
||||||
|
else self.capture_bs
|
||||||
|
)
|
||||||
|
for bs in capture_bs:
|
||||||
with patch_model(
|
with patch_model(
|
||||||
self.model_runner.model,
|
self.model_runner.model,
|
||||||
bs in self.compile_bs,
|
bs in self.compile_bs,
|
||||||
|
|||||||
Reference in New Issue
Block a user