Move torch.compile configs into cuda_graph_runner.py (#993)

This commit is contained in:
Ying Sheng
2024-08-08 13:20:30 -07:00
committed by GitHub
parent ab7875941b
commit 9f662501a3
4 changed files with 15 additions and 19 deletions

View File

@@ -71,6 +71,18 @@ def patch_model(
tp_group.ca_comm = backup_ca_comm
def set_torch_compile_config():
import torch._dynamo.config
import torch._inductor.config
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
# FIXME: tmp workaround
torch._dynamo.config.accumulated_cache_size_limit = 1024
class CudaGraphRunner:
def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
self.model_runner = model_runner
@@ -112,6 +124,9 @@ class CudaGraphRunner:
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
if use_torch_compile:
set_torch_compile_config()
def can_run(self, batch_size):
return batch_size < self.max_bs

View File

@@ -74,7 +74,6 @@ from sglang.srt.utils import (
enable_show_time_cost,
kill_child_process,
maybe_set_triton_cache_manager,
set_torch_compile_config,
set_ulimit,
)
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
@@ -347,10 +346,6 @@ def _set_envs_and_config(server_args: ServerArgs):
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
maybe_set_triton_cache_manager()
# Set torch compile config
if server_args.enable_torch_compile:
set_torch_compile_config()
# Set global chat template
if server_args.chat_template:
# TODO: replace this with huggingface transformers template

View File

@@ -622,19 +622,6 @@ def receive_addrs(model_port_args, server_args):
dist.destroy_process_group()
def set_torch_compile_config():
# The following configurations are for torch compile optimizations
import torch._dynamo.config
import torch._inductor.config
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
# FIXME: tmp workaround
torch._dynamo.config.accumulated_cache_size_limit = 256
def set_ulimit(target_soft_limit=65535):
resource_type = resource.RLIMIT_NOFILE
current_soft, current_hard = resource.getrlimit(resource_type)

View File

@@ -6,7 +6,6 @@ import json
import logging
import signal
import sys
import threading
import traceback
import urllib.request
from concurrent.futures import ThreadPoolExecutor