From 9f662501a36b332ec4ac9b4ece29233ad7563c01 Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Thu, 8 Aug 2024 13:20:30 -0700 Subject: [PATCH] Move torch.compile configs into cuda_graph_runner.py (#993) --- .../srt/model_executor/cuda_graph_runner.py | 15 +++++++++++++++ python/sglang/srt/server.py | 5 ----- python/sglang/srt/utils.py | 13 ------------- python/sglang/utils.py | 1 - 4 files changed, 15 insertions(+), 19 deletions(-) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index ae6fe83c5..9bfd4a646 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -71,6 +71,18 @@ def patch_model( tp_group.ca_comm = backup_ca_comm +def set_torch_compile_config(): + import torch._dynamo.config + import torch._inductor.config + + torch._inductor.config.coordinate_descent_tuning = True + torch._inductor.config.triton.unique_kernel_names = True + torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future + + # FIXME: tmp workaround + torch._dynamo.config.accumulated_cache_size_limit = 1024 + + class CudaGraphRunner: def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile): self.model_runner = model_runner @@ -112,6 +124,9 @@ class CudaGraphRunner: self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else [] + if use_torch_compile: + set_torch_compile_config() + def can_run(self, batch_size): return batch_size < self.max_bs diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 203e6a457..0443e9f2a 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -74,7 +74,6 @@ from sglang.srt.utils import ( enable_show_time_cost, kill_child_process, maybe_set_triton_cache_manager, - set_torch_compile_config, set_ulimit, ) from sglang.utils import get_exception_traceback @@ -347,10 +346,6 @@ def _set_envs_and_config(server_args: ServerArgs): # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. maybe_set_triton_cache_manager() - # Set torch compile config - if server_args.enable_torch_compile: - set_torch_compile_config() - # Set global chat template if server_args.chat_template: # TODO: replace this with huggingface transformers template diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 172c93c74..e15cb6751 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -622,19 +622,6 @@ def receive_addrs(model_port_args, server_args): dist.destroy_process_group() -def set_torch_compile_config(): - # The following configurations are for torch compile optimizations - import torch._dynamo.config - import torch._inductor.config - - torch._inductor.config.coordinate_descent_tuning = True - torch._inductor.config.triton.unique_kernel_names = True - torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future - - # FIXME: tmp workaround - torch._dynamo.config.accumulated_cache_size_limit = 256 - - def set_ulimit(target_soft_limit=65535): resource_type = resource.RLIMIT_NOFILE current_soft, current_hard = resource.getrlimit(resource_type) diff --git a/python/sglang/utils.py b/python/sglang/utils.py index c1193df3c..c880d259d 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -6,7 +6,6 @@ import json import logging import signal import sys -import threading import traceback import urllib.request from concurrent.futures import ThreadPoolExecutor