Move torch.compile configs into cuda_graph_runner.py (#993)
This commit is contained in:
@@ -71,6 +71,18 @@ def patch_model(
|
||||
tp_group.ca_comm = backup_ca_comm
|
||||
|
||||
|
||||
def set_torch_compile_config():
|
||||
import torch._dynamo.config
|
||||
import torch._inductor.config
|
||||
|
||||
torch._inductor.config.coordinate_descent_tuning = True
|
||||
torch._inductor.config.triton.unique_kernel_names = True
|
||||
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
|
||||
|
||||
# FIXME: tmp workaround
|
||||
torch._dynamo.config.accumulated_cache_size_limit = 1024
|
||||
|
||||
|
||||
class CudaGraphRunner:
|
||||
def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
|
||||
self.model_runner = model_runner
|
||||
@@ -112,6 +124,9 @@ class CudaGraphRunner:
|
||||
|
||||
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
|
||||
|
||||
if use_torch_compile:
|
||||
set_torch_compile_config()
|
||||
|
||||
def can_run(self, batch_size):
|
||||
return batch_size < self.max_bs
|
||||
|
||||
|
||||
@@ -74,7 +74,6 @@ from sglang.srt.utils import (
|
||||
enable_show_time_cost,
|
||||
kill_child_process,
|
||||
maybe_set_triton_cache_manager,
|
||||
set_torch_compile_config,
|
||||
set_ulimit,
|
||||
)
|
||||
from sglang.utils import get_exception_traceback
|
||||
@@ -347,10 +346,6 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
|
||||
maybe_set_triton_cache_manager()
|
||||
|
||||
# Set torch compile config
|
||||
if server_args.enable_torch_compile:
|
||||
set_torch_compile_config()
|
||||
|
||||
# Set global chat template
|
||||
if server_args.chat_template:
|
||||
# TODO: replace this with huggingface transformers template
|
||||
|
||||
@@ -622,19 +622,6 @@ def receive_addrs(model_port_args, server_args):
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def set_torch_compile_config():
|
||||
# The following configurations are for torch compile optimizations
|
||||
import torch._dynamo.config
|
||||
import torch._inductor.config
|
||||
|
||||
torch._inductor.config.coordinate_descent_tuning = True
|
||||
torch._inductor.config.triton.unique_kernel_names = True
|
||||
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
|
||||
|
||||
# FIXME: tmp workaround
|
||||
torch._dynamo.config.accumulated_cache_size_limit = 256
|
||||
|
||||
|
||||
def set_ulimit(target_soft_limit=65535):
|
||||
resource_type = resource.RLIMIT_NOFILE
|
||||
current_soft, current_hard = resource.getrlimit(resource_type)
|
||||
|
||||
@@ -6,7 +6,6 @@ import json
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
import urllib.request
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
Reference in New Issue
Block a user