Move torch.compile configs into cuda_graph_runner.py (#993)
This commit is contained in:
@@ -71,6 +71,18 @@ def patch_model(
|
|||||||
tp_group.ca_comm = backup_ca_comm
|
tp_group.ca_comm = backup_ca_comm
|
||||||
|
|
||||||
|
|
||||||
|
def set_torch_compile_config():
|
||||||
|
import torch._dynamo.config
|
||||||
|
import torch._inductor.config
|
||||||
|
|
||||||
|
torch._inductor.config.coordinate_descent_tuning = True
|
||||||
|
torch._inductor.config.triton.unique_kernel_names = True
|
||||||
|
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
|
||||||
|
|
||||||
|
# FIXME: tmp workaround
|
||||||
|
torch._dynamo.config.accumulated_cache_size_limit = 1024
|
||||||
|
|
||||||
|
|
||||||
class CudaGraphRunner:
|
class CudaGraphRunner:
|
||||||
def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
|
def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
|
||||||
self.model_runner = model_runner
|
self.model_runner = model_runner
|
||||||
@@ -112,6 +124,9 @@ class CudaGraphRunner:
|
|||||||
|
|
||||||
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
|
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
|
||||||
|
|
||||||
|
if use_torch_compile:
|
||||||
|
set_torch_compile_config()
|
||||||
|
|
||||||
def can_run(self, batch_size):
|
def can_run(self, batch_size):
|
||||||
return batch_size < self.max_bs
|
return batch_size < self.max_bs
|
||||||
|
|
||||||
|
|||||||
@@ -74,7 +74,6 @@ from sglang.srt.utils import (
|
|||||||
enable_show_time_cost,
|
enable_show_time_cost,
|
||||||
kill_child_process,
|
kill_child_process,
|
||||||
maybe_set_triton_cache_manager,
|
maybe_set_triton_cache_manager,
|
||||||
set_torch_compile_config,
|
|
||||||
set_ulimit,
|
set_ulimit,
|
||||||
)
|
)
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
@@ -347,10 +346,6 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
|
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
|
||||||
maybe_set_triton_cache_manager()
|
maybe_set_triton_cache_manager()
|
||||||
|
|
||||||
# Set torch compile config
|
|
||||||
if server_args.enable_torch_compile:
|
|
||||||
set_torch_compile_config()
|
|
||||||
|
|
||||||
# Set global chat template
|
# Set global chat template
|
||||||
if server_args.chat_template:
|
if server_args.chat_template:
|
||||||
# TODO: replace this with huggingface transformers template
|
# TODO: replace this with huggingface transformers template
|
||||||
|
|||||||
@@ -622,19 +622,6 @@ def receive_addrs(model_port_args, server_args):
|
|||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
def set_torch_compile_config():
|
|
||||||
# The following configurations are for torch compile optimizations
|
|
||||||
import torch._dynamo.config
|
|
||||||
import torch._inductor.config
|
|
||||||
|
|
||||||
torch._inductor.config.coordinate_descent_tuning = True
|
|
||||||
torch._inductor.config.triton.unique_kernel_names = True
|
|
||||||
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
|
|
||||||
|
|
||||||
# FIXME: tmp workaround
|
|
||||||
torch._dynamo.config.accumulated_cache_size_limit = 256
|
|
||||||
|
|
||||||
|
|
||||||
def set_ulimit(target_soft_limit=65535):
|
def set_ulimit(target_soft_limit=65535):
|
||||||
resource_type = resource.RLIMIT_NOFILE
|
resource_type = resource.RLIMIT_NOFILE
|
||||||
current_soft, current_hard = resource.getrlimit(resource_type)
|
current_soft, current_hard = resource.getrlimit(resource_type)
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import threading
|
|
||||||
import traceback
|
import traceback
|
||||||
import urllib.request
|
import urllib.request
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|||||||
Reference in New Issue
Block a user