Eager Compiler for Torch Compile (#11803)
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
This commit is contained in:
@@ -17,24 +17,30 @@ from torch._dispatch.python import enable_python_dispatcher
|
|||||||
|
|
||||||
from sglang.srt.compilation.compilation_config import CompilationConfig
|
from sglang.srt.compilation.compilation_config import CompilationConfig
|
||||||
from sglang.srt.compilation.compilation_counter import compilation_counter
|
from sglang.srt.compilation.compilation_counter import compilation_counter
|
||||||
from sglang.srt.compilation.compiler_interface import InductorAdaptor
|
from sglang.srt.compilation.compiler_interface import EagerAdapter, InductorAdaptor
|
||||||
from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend
|
from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend
|
||||||
from sglang.srt.compilation.pass_manager import PostGradPassManager
|
from sglang.srt.compilation.pass_manager import PostGradPassManager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def make_compiler():
|
def make_compiler(config: CompilationConfig):
|
||||||
return InductorAdaptor()
|
if config.compiler == "eager":
|
||||||
|
return EagerAdapter()
|
||||||
|
elif config.compiler == "inductor":
|
||||||
|
return InductorAdaptor()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown compiler: {config.compiler}")
|
||||||
|
|
||||||
|
|
||||||
class CompilerManager:
|
class CompilerManager:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
config: CompilationConfig,
|
||||||
):
|
):
|
||||||
self.cache = dict()
|
self.cache = dict()
|
||||||
self.is_cache_updated = False
|
self.is_cache_updated = False
|
||||||
self.compiler = make_compiler()
|
self.compiler = make_compiler(config)
|
||||||
|
|
||||||
def compute_hash(self):
|
def compute_hash(self):
|
||||||
return self.compiler.compute_hash()
|
return self.compiler.compute_hash()
|
||||||
@@ -348,7 +354,7 @@ class SGLangBackend:
|
|||||||
self.sym_tensor_indices = []
|
self.sym_tensor_indices = []
|
||||||
self.input_buffers = []
|
self.input_buffers = []
|
||||||
|
|
||||||
self.compiler_manager = CompilerManager()
|
self.compiler_manager = CompilerManager(config)
|
||||||
self.inductor_config = {
|
self.inductor_config = {
|
||||||
"enable_auto_functionalized_v2": False,
|
"enable_auto_functionalized_v2": False,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,9 +5,10 @@ from typing import List
|
|||||||
|
|
||||||
# TODO(Yuwei): support better compile config support
|
# TODO(Yuwei): support better compile config support
|
||||||
class CompilationConfig:
|
class CompilationConfig:
|
||||||
def __init__(self, capture_sizes: List[int]):
|
def __init__(self, capture_sizes: List[int], compiler: str = "eager"):
|
||||||
self.traced_files = set()
|
self.traced_files = set()
|
||||||
self.capture_sizes = capture_sizes
|
self.capture_sizes = capture_sizes
|
||||||
|
self.compiler = compiler
|
||||||
|
|
||||||
def add_traced_file(self, file_path: str):
|
def add_traced_file(self, file_path: str):
|
||||||
self.traced_files.add(file_path)
|
self.traced_files.add(file_path)
|
||||||
|
|||||||
@@ -475,3 +475,29 @@ def set_inductor_config(config, runtime_shape):
|
|||||||
# can be beneficial
|
# can be beneficial
|
||||||
config["max_autotune"] = True
|
config["max_autotune"] = True
|
||||||
config["coordinate_descent_tuning"] = True
|
config["coordinate_descent_tuning"] = True
|
||||||
|
|
||||||
|
|
||||||
|
class EagerAdapter(CompilerInterface):
|
||||||
|
name = "eager"
|
||||||
|
|
||||||
|
def compile(
|
||||||
|
self,
|
||||||
|
graph: fx.GraphModule,
|
||||||
|
example_inputs: list[Any],
|
||||||
|
compiler_config: dict[str, Any],
|
||||||
|
runtime_shape: Optional[int] = None,
|
||||||
|
key: Optional[str] = None,
|
||||||
|
num_graphs: int = 1,
|
||||||
|
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||||
|
return graph, None
|
||||||
|
|
||||||
|
def load(
|
||||||
|
self,
|
||||||
|
handle: Any,
|
||||||
|
graph: fx.GraphModule,
|
||||||
|
example_inputs: list[Any],
|
||||||
|
graph_index: int,
|
||||||
|
runtime_shape: Optional[int] = None,
|
||||||
|
num_graphs: int = 1,
|
||||||
|
) -> Callable:
|
||||||
|
raise NotImplementedError("eager compilation is not supported")
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from unittest.mock import patch
|
|||||||
import torch
|
import torch
|
||||||
import torch.fx as fx
|
import torch.fx as fx
|
||||||
|
|
||||||
|
import sglang.srt.compilation.weak_ref_tensor_jit # noqa: F401
|
||||||
from sglang.srt.compilation.compilation_config import CompilationConfig
|
from sglang.srt.compilation.compilation_config import CompilationConfig
|
||||||
from sglang.srt.compilation.compilation_counter import compilation_counter
|
from sglang.srt.compilation.compilation_counter import compilation_counter
|
||||||
|
|
||||||
|
|||||||
@@ -103,9 +103,10 @@ def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
|
|||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch_model(model: torch.nn.Module):
|
def patch_model(model: torch.nn.Module, compiler: str):
|
||||||
try:
|
try:
|
||||||
_to_torch(model, reverse=False, num_tokens=16)
|
if compiler != "eager":
|
||||||
|
_to_torch(model, reverse=False, num_tokens=16)
|
||||||
yield model
|
yield model
|
||||||
finally:
|
finally:
|
||||||
_to_torch(model, reverse=True, num_tokens=16)
|
_to_torch(model, reverse=True, num_tokens=16)
|
||||||
@@ -144,8 +145,13 @@ class PiecewiseCudaGraphRunner:
|
|||||||
assert (
|
assert (
|
||||||
self.model_runner.server_args.piecewise_cuda_graph_tokens is not None
|
self.model_runner.server_args.piecewise_cuda_graph_tokens is not None
|
||||||
), "piecewise_cuda_graph_tokens is not set"
|
), "piecewise_cuda_graph_tokens is not set"
|
||||||
|
assert self.model_runner.server_args.piecewise_cuda_graph_compiler in [
|
||||||
|
"eager",
|
||||||
|
"inductor",
|
||||||
|
], "By now, only eager and inductor are supported for piecewise cuda graph compiler."
|
||||||
self.compile_config = CompilationConfig(
|
self.compile_config = CompilationConfig(
|
||||||
self.model_runner.server_args.piecewise_cuda_graph_tokens
|
self.model_runner.server_args.piecewise_cuda_graph_tokens,
|
||||||
|
self.model_runner.server_args.piecewise_cuda_graph_compiler,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Batch sizes to capture
|
# Batch sizes to capture
|
||||||
@@ -179,7 +185,9 @@ class PiecewiseCudaGraphRunner:
|
|||||||
# Set graph pool id globally to be able to use symmetric memory
|
# Set graph pool id globally to be able to use symmetric memory
|
||||||
set_graph_pool_id(get_global_graph_memory_pool())
|
set_graph_pool_id(get_global_graph_memory_pool())
|
||||||
|
|
||||||
with patch_model(self.model_runner.model.model) as patched_model:
|
with patch_model(
|
||||||
|
self.model_runner.model.model, self.compile_config.compiler
|
||||||
|
) as patched_model:
|
||||||
install_torch_compiled(
|
install_torch_compiled(
|
||||||
patched_model,
|
patched_model,
|
||||||
fullgraph=True,
|
fullgraph=True,
|
||||||
@@ -191,14 +199,14 @@ class PiecewiseCudaGraphRunner:
|
|||||||
with set_compiled(True):
|
with set_compiled(True):
|
||||||
self.warmup_and_capture()
|
self.warmup_and_capture()
|
||||||
|
|
||||||
# Capture
|
# Capture
|
||||||
try:
|
try:
|
||||||
with model_capture_mode():
|
with model_capture_mode():
|
||||||
self.capture()
|
self.capture()
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Capture cuda graph failed: {e}\n{PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
f"Capture cuda graph failed: {e}\n{PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.raw_num_tokens = 0
|
self.raw_num_tokens = 0
|
||||||
|
|
||||||
|
|||||||
@@ -436,6 +436,7 @@ class ServerArgs:
|
|||||||
torch_compile_max_bs: int = 32
|
torch_compile_max_bs: int = 32
|
||||||
piecewise_cuda_graph_max_tokens: int = 4096
|
piecewise_cuda_graph_max_tokens: int = 4096
|
||||||
piecewise_cuda_graph_tokens: Optional[List[int]] = None
|
piecewise_cuda_graph_tokens: Optional[List[int]] = None
|
||||||
|
piecewise_cuda_graph_compiler: str = "eager"
|
||||||
torchao_config: str = ""
|
torchao_config: str = ""
|
||||||
enable_nan_detection: bool = False
|
enable_nan_detection: bool = False
|
||||||
enable_p2p_check: bool = False
|
enable_p2p_check: bool = False
|
||||||
@@ -2815,6 +2816,13 @@ class ServerArgs:
|
|||||||
default=ServerArgs.piecewise_cuda_graph_tokens,
|
default=ServerArgs.piecewise_cuda_graph_tokens,
|
||||||
help="Set the list of tokens when using piecewise cuda graph.",
|
help="Set the list of tokens when using piecewise cuda graph.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--piecewise-cuda-graph-compiler",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.piecewise_cuda_graph_compiler,
|
||||||
|
help="Set the compiler for piecewise cuda graph. Choices are: eager, inductor.",
|
||||||
|
choices=["eager", "inductor"],
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--torch-compile-max-bs",
|
"--torch-compile-max-bs",
|
||||||
type=int,
|
type=int,
|
||||||
|
|||||||
Reference in New Issue
Block a user