Eager Compiler for Torch Compile (#11803)
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
This commit is contained in:
@@ -17,24 +17,30 @@ from torch._dispatch.python import enable_python_dispatcher
|
||||
|
||||
from sglang.srt.compilation.compilation_config import CompilationConfig
|
||||
from sglang.srt.compilation.compilation_counter import compilation_counter
|
||||
from sglang.srt.compilation.compiler_interface import InductorAdaptor
|
||||
from sglang.srt.compilation.compiler_interface import EagerAdapter, InductorAdaptor
|
||||
from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend
|
||||
from sglang.srt.compilation.pass_manager import PostGradPassManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_compiler():
|
||||
return InductorAdaptor()
|
||||
def make_compiler(config: CompilationConfig):
|
||||
if config.compiler == "eager":
|
||||
return EagerAdapter()
|
||||
elif config.compiler == "inductor":
|
||||
return InductorAdaptor()
|
||||
else:
|
||||
raise ValueError(f"Unknown compiler: {config.compiler}")
|
||||
|
||||
|
||||
class CompilerManager:
|
||||
def __init__(
|
||||
self,
|
||||
config: CompilationConfig,
|
||||
):
|
||||
self.cache = dict()
|
||||
self.is_cache_updated = False
|
||||
self.compiler = make_compiler()
|
||||
self.compiler = make_compiler(config)
|
||||
|
||||
def compute_hash(self):
|
||||
return self.compiler.compute_hash()
|
||||
@@ -348,7 +354,7 @@ class SGLangBackend:
|
||||
self.sym_tensor_indices = []
|
||||
self.input_buffers = []
|
||||
|
||||
self.compiler_manager = CompilerManager()
|
||||
self.compiler_manager = CompilerManager(config)
|
||||
self.inductor_config = {
|
||||
"enable_auto_functionalized_v2": False,
|
||||
}
|
||||
|
||||
@@ -5,9 +5,10 @@ from typing import List
|
||||
|
||||
# TODO(Yuwei): support better compile config support
|
||||
class CompilationConfig:
|
||||
def __init__(self, capture_sizes: List[int]):
|
||||
def __init__(self, capture_sizes: List[int], compiler: str = "eager"):
|
||||
self.traced_files = set()
|
||||
self.capture_sizes = capture_sizes
|
||||
self.compiler = compiler
|
||||
|
||||
def add_traced_file(self, file_path: str):
|
||||
self.traced_files.add(file_path)
|
||||
|
||||
@@ -475,3 +475,29 @@ def set_inductor_config(config, runtime_shape):
|
||||
# can be beneficial
|
||||
config["max_autotune"] = True
|
||||
config["coordinate_descent_tuning"] = True
|
||||
|
||||
|
||||
class EagerAdapter(CompilerInterface):
|
||||
name = "eager"
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: Optional[int] = None,
|
||||
key: Optional[str] = None,
|
||||
num_graphs: int = 1,
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
return graph, None
|
||||
|
||||
def load(
|
||||
self,
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None,
|
||||
num_graphs: int = 1,
|
||||
) -> Callable:
|
||||
raise NotImplementedError("eager compilation is not supported")
|
||||
|
||||
@@ -9,6 +9,7 @@ from unittest.mock import patch
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
|
||||
import sglang.srt.compilation.weak_ref_tensor_jit # noqa: F401
|
||||
from sglang.srt.compilation.compilation_config import CompilationConfig
|
||||
from sglang.srt.compilation.compilation_counter import compilation_counter
|
||||
|
||||
|
||||
@@ -103,9 +103,10 @@ def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
|
||||
|
||||
|
||||
@contextmanager
|
||||
def patch_model(model: torch.nn.Module):
|
||||
def patch_model(model: torch.nn.Module, compiler: str):
|
||||
try:
|
||||
_to_torch(model, reverse=False, num_tokens=16)
|
||||
if compiler != "eager":
|
||||
_to_torch(model, reverse=False, num_tokens=16)
|
||||
yield model
|
||||
finally:
|
||||
_to_torch(model, reverse=True, num_tokens=16)
|
||||
@@ -144,8 +145,13 @@ class PiecewiseCudaGraphRunner:
|
||||
assert (
|
||||
self.model_runner.server_args.piecewise_cuda_graph_tokens is not None
|
||||
), "piecewise_cuda_graph_tokens is not set"
|
||||
assert self.model_runner.server_args.piecewise_cuda_graph_compiler in [
|
||||
"eager",
|
||||
"inductor",
|
||||
], "By now, only eager and inductor are supported for piecewise cuda graph compiler."
|
||||
self.compile_config = CompilationConfig(
|
||||
self.model_runner.server_args.piecewise_cuda_graph_tokens
|
||||
self.model_runner.server_args.piecewise_cuda_graph_tokens,
|
||||
self.model_runner.server_args.piecewise_cuda_graph_compiler,
|
||||
)
|
||||
|
||||
# Batch sizes to capture
|
||||
@@ -179,7 +185,9 @@ class PiecewiseCudaGraphRunner:
|
||||
# Set graph pool id globally to be able to use symmetric memory
|
||||
set_graph_pool_id(get_global_graph_memory_pool())
|
||||
|
||||
with patch_model(self.model_runner.model.model) as patched_model:
|
||||
with patch_model(
|
||||
self.model_runner.model.model, self.compile_config.compiler
|
||||
) as patched_model:
|
||||
install_torch_compiled(
|
||||
patched_model,
|
||||
fullgraph=True,
|
||||
@@ -191,14 +199,14 @@ class PiecewiseCudaGraphRunner:
|
||||
with set_compiled(True):
|
||||
self.warmup_and_capture()
|
||||
|
||||
# Capture
|
||||
try:
|
||||
with model_capture_mode():
|
||||
self.capture()
|
||||
except RuntimeError as e:
|
||||
raise Exception(
|
||||
f"Capture cuda graph failed: {e}\n{PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
||||
)
|
||||
# Capture
|
||||
try:
|
||||
with model_capture_mode():
|
||||
self.capture()
|
||||
except RuntimeError as e:
|
||||
raise Exception(
|
||||
f"Capture cuda graph failed: {e}\n{PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
||||
)
|
||||
|
||||
self.raw_num_tokens = 0
|
||||
|
||||
|
||||
@@ -436,6 +436,7 @@ class ServerArgs:
|
||||
torch_compile_max_bs: int = 32
|
||||
piecewise_cuda_graph_max_tokens: int = 4096
|
||||
piecewise_cuda_graph_tokens: Optional[List[int]] = None
|
||||
piecewise_cuda_graph_compiler: str = "eager"
|
||||
torchao_config: str = ""
|
||||
enable_nan_detection: bool = False
|
||||
enable_p2p_check: bool = False
|
||||
@@ -2815,6 +2816,13 @@ class ServerArgs:
|
||||
default=ServerArgs.piecewise_cuda_graph_tokens,
|
||||
help="Set the list of tokens when using piecewise cuda graph.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--piecewise-cuda-graph-compiler",
|
||||
type=str,
|
||||
default=ServerArgs.piecewise_cuda_graph_compiler,
|
||||
help="Set the compiler for piecewise cuda graph. Choices are: eager, inductor.",
|
||||
choices=["eager", "inductor"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--torch-compile-max-bs",
|
||||
type=int,
|
||||
|
||||
Reference in New Issue
Block a user