From 1d726528f7d09e63d65445ad03c2f6d6948a3525 Mon Sep 17 00:00:00 2001 From: Yuwei An Date: Sat, 18 Oct 2025 00:18:52 -0700 Subject: [PATCH] Eager Compiler for Torch Compile (#11803) Signed-off-by: Oasis-Git --- python/sglang/srt/compilation/backend.py | 16 +++++++--- .../srt/compilation/compilation_config.py | 3 +- .../srt/compilation/compiler_interface.py | 26 +++++++++++++++ .../srt/compilation/cuda_piecewise_backend.py | 1 + .../piecewise_cuda_graph_runner.py | 32 ++++++++++++------- python/sglang/srt/server_args.py | 8 +++++ 6 files changed, 68 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/compilation/backend.py b/python/sglang/srt/compilation/backend.py index 88171a124..7eb650305 100644 --- a/python/sglang/srt/compilation/backend.py +++ b/python/sglang/srt/compilation/backend.py @@ -17,24 +17,30 @@ from torch._dispatch.python import enable_python_dispatcher from sglang.srt.compilation.compilation_config import CompilationConfig from sglang.srt.compilation.compilation_counter import compilation_counter -from sglang.srt.compilation.compiler_interface import InductorAdaptor +from sglang.srt.compilation.compiler_interface import EagerAdapter, InductorAdaptor from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend from sglang.srt.compilation.pass_manager import PostGradPassManager logger = logging.getLogger(__name__) -def make_compiler(): - return InductorAdaptor() +def make_compiler(config: CompilationConfig): + if config.compiler == "eager": + return EagerAdapter() + elif config.compiler == "inductor": + return InductorAdaptor() + else: + raise ValueError(f"Unknown compiler: {config.compiler}") class CompilerManager: def __init__( self, + config: CompilationConfig, ): self.cache = dict() self.is_cache_updated = False - self.compiler = make_compiler() + self.compiler = make_compiler(config) def compute_hash(self): return self.compiler.compute_hash() @@ -348,7 +354,7 @@ class SGLangBackend: self.sym_tensor_indices = [] self.input_buffers = [] - self.compiler_manager = CompilerManager() + self.compiler_manager = CompilerManager(config) self.inductor_config = { "enable_auto_functionalized_v2": False, } diff --git a/python/sglang/srt/compilation/compilation_config.py b/python/sglang/srt/compilation/compilation_config.py index 7a8ef6436..fbf1493e1 100644 --- a/python/sglang/srt/compilation/compilation_config.py +++ b/python/sglang/srt/compilation/compilation_config.py @@ -5,9 +5,10 @@ from typing import List # TODO(Yuwei): support better compile config support class CompilationConfig: - def __init__(self, capture_sizes: List[int]): + def __init__(self, capture_sizes: List[int], compiler: str = "eager"): self.traced_files = set() self.capture_sizes = capture_sizes + self.compiler = compiler def add_traced_file(self, file_path: str): self.traced_files.add(file_path) diff --git a/python/sglang/srt/compilation/compiler_interface.py b/python/sglang/srt/compilation/compiler_interface.py index 0c58a0dea..8310f75c9 100644 --- a/python/sglang/srt/compilation/compiler_interface.py +++ b/python/sglang/srt/compilation/compiler_interface.py @@ -475,3 +475,29 @@ def set_inductor_config(config, runtime_shape): # can be beneficial config["max_autotune"] = True config["coordinate_descent_tuning"] = True + + +class EagerAdapter(CompilerInterface): + name = "eager" + + def compile( + self, + graph: fx.GraphModule, + example_inputs: list[Any], + compiler_config: dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, + num_graphs: int = 1, + ) -> tuple[Optional[Callable], Optional[Any]]: + return graph, None + + def load( + self, + handle: Any, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None, + num_graphs: int = 1, + ) -> Callable: + raise NotImplementedError("eager compilation is not supported") diff --git a/python/sglang/srt/compilation/cuda_piecewise_backend.py b/python/sglang/srt/compilation/cuda_piecewise_backend.py index 44e3803ff..b96755d4f 100644 --- a/python/sglang/srt/compilation/cuda_piecewise_backend.py +++ b/python/sglang/srt/compilation/cuda_piecewise_backend.py @@ -9,6 +9,7 @@ from unittest.mock import patch import torch import torch.fx as fx +import sglang.srt.compilation.weak_ref_tensor_jit # noqa: F401 from sglang.srt.compilation.compilation_config import CompilationConfig from sglang.srt.compilation.compilation_counter import compilation_counter diff --git a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py index e4f9002b7..922642e56 100644 --- a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py @@ -103,9 +103,10 @@ def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int): @contextmanager -def patch_model(model: torch.nn.Module): +def patch_model(model: torch.nn.Module, compiler: str): try: - _to_torch(model, reverse=False, num_tokens=16) + if compiler != "eager": + _to_torch(model, reverse=False, num_tokens=16) yield model finally: _to_torch(model, reverse=True, num_tokens=16) @@ -144,8 +145,13 @@ class PiecewiseCudaGraphRunner: assert ( self.model_runner.server_args.piecewise_cuda_graph_tokens is not None ), "piecewise_cuda_graph_tokens is not set" + assert self.model_runner.server_args.piecewise_cuda_graph_compiler in [ + "eager", + "inductor", + ], "By now, only eager and inductor are supported for piecewise cuda graph compiler." self.compile_config = CompilationConfig( - self.model_runner.server_args.piecewise_cuda_graph_tokens + self.model_runner.server_args.piecewise_cuda_graph_tokens, + self.model_runner.server_args.piecewise_cuda_graph_compiler, ) # Batch sizes to capture @@ -179,7 +185,9 @@ class PiecewiseCudaGraphRunner: # Set graph pool id globally to be able to use symmetric memory set_graph_pool_id(get_global_graph_memory_pool()) - with patch_model(self.model_runner.model.model) as patched_model: + with patch_model( + self.model_runner.model.model, self.compile_config.compiler + ) as patched_model: install_torch_compiled( patched_model, fullgraph=True, @@ -191,14 +199,14 @@ class PiecewiseCudaGraphRunner: with set_compiled(True): self.warmup_and_capture() - # Capture - try: - with model_capture_mode(): - self.capture() - except RuntimeError as e: - raise Exception( - f"Capture cuda graph failed: {e}\n{PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG}" - ) + # Capture + try: + with model_capture_mode(): + self.capture() + except RuntimeError as e: + raise Exception( + f"Capture cuda graph failed: {e}\n{PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG}" + ) self.raw_num_tokens = 0 diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b6074f86b..e363dbf7d 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -436,6 +436,7 @@ class ServerArgs: torch_compile_max_bs: int = 32 piecewise_cuda_graph_max_tokens: int = 4096 piecewise_cuda_graph_tokens: Optional[List[int]] = None + piecewise_cuda_graph_compiler: str = "eager" torchao_config: str = "" enable_nan_detection: bool = False enable_p2p_check: bool = False @@ -2815,6 +2816,13 @@ class ServerArgs: default=ServerArgs.piecewise_cuda_graph_tokens, help="Set the list of tokens when using piecewise cuda graph.", ) + parser.add_argument( + "--piecewise-cuda-graph-compiler", + type=str, + default=ServerArgs.piecewise_cuda_graph_compiler, + help="Set the compiler for piecewise cuda graph. Choices are: eager, inductor.", + choices=["eager", "inductor"], + ) parser.add_argument( "--torch-compile-max-bs", type=int,