diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index 0a506d35f..6836c9bc9 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -185,7 +185,7 @@ class CustomAllreduce: # is enough for 131072 such tuples. The largest model I've seen only # needs less than 10000 of registered tuples. self.rank_data = torch.empty( - 8 * 1024 * 1024, dtype=torch.uint8, device=self.device + max_size, dtype=torch.uint8, device=self.device ) self._ptr = ops.init_custom_ar( self.meta_ptrs, self.rank_data, rank, self.full_nvlink @@ -202,7 +202,7 @@ class CustomAllreduce: ) handles, offsets = self._gather_ipc_meta(shard_data) self.rank_data = torch.empty( - 8 * 1024 * 1024, dtype=torch.uint8, device=self.device + max_size, dtype=torch.uint8, device=self.device ) self._ptr = ops.init_custom_ar( self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 009aba52e..daf18e68c 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -239,6 +239,7 @@ class GroupCoordinator: use_npu_communicator: bool, use_message_queue_broadcaster: bool = False, group_name: Optional[str] = None, + torch_compile: Optional[bool] = None, ): # Set group info group_name = group_name or "anonymous" @@ -326,10 +327,18 @@ class GroupCoordinator: self.qr_comm: Optional[QuickAllReduce] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. + if torch_compile is not None and torch_compile: + # For piecewise CUDA graph, the requirement for custom allreduce is larger to + # avoid illegal cuda memory access. + ca_max_size = 256 * 1024 * 1024 + else: + ca_max_size = 8 * 1024 * 1024 try: + # print(f"ca_max_size: {ca_max_size}") self.ca_comm = CustomAllreduce( group=self.cpu_group, device=self.device, + max_size=ca_max_size, ) except Exception as e: logger.warning( @@ -1260,6 +1269,7 @@ def init_model_parallel_group( group_name: Optional[str] = None, use_mscclpp_allreduce: Optional[bool] = None, use_symm_mem_allreduce: Optional[bool] = None, + torch_compile: Optional[bool] = None, ) -> GroupCoordinator: if use_custom_allreduce is None: use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE @@ -1280,6 +1290,7 @@ def init_model_parallel_group( use_npu_communicator=True, use_message_queue_broadcaster=use_message_queue_broadcaster, group_name=group_name, + torch_compile=torch_compile, ) @@ -1439,6 +1450,7 @@ def initialize_model_parallel( pipeline_model_parallel_size: int = 1, backend: Optional[str] = None, duplicate_tp_group: bool = False, + torch_compile: Optional[bool] = None, ) -> None: """ Initialize model parallel groups. 
@@ -1494,6 +1506,7 @@ def initialize_model_parallel( "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" ), group_name="tp", + torch_compile=torch_compile, ) if duplicate_tp_group: @@ -1509,6 +1522,7 @@ def initialize_model_parallel( "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" ), group_name="pdmux_prefill_tp", + torch_compile=torch_compile, ) _TP.pynccl_comm.disabled = False _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False @@ -1518,7 +1532,6 @@ def initialize_model_parallel( global _MOE_EP assert _MOE_EP is None, "expert model parallel group is already initialized" - if moe_ep_size == tensor_model_parallel_size: _MOE_EP = _TP else: @@ -1539,7 +1552,6 @@ def initialize_model_parallel( global _MOE_TP assert _MOE_TP is None, "expert model parallel group is already initialized" - if moe_tp_size == tensor_model_parallel_size: _MOE_TP = _TP else: diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 87a392d55..399ef3e71 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -43,11 +43,16 @@ _is_cpu = is_cpu() _is_xpu = is_xpu() if _is_cuda: - if _is_flashinfer_available: - from flashinfer.norm import fused_add_rmsnorm - else: - from sgl_kernel import fused_add_rmsnorm - from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm + # if _is_flashinfer_available: + # from flashinfer.norm import fused_add_rmsnorm + # else: + from sgl_kernel import ( + fused_add_rmsnorm, + gemma_fused_add_rmsnorm, + gemma_rmsnorm, + rmsnorm, + ) + if _use_aiter: from aiter import rmsnorm2d_fwd as rms_norm diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 0719cdd29..bd5866137 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -17,12 +17,18 @@ from __future__ import annotations from enum import Enum from typing import TYPE_CHECKING, Optional +import torch from torch import nn if TYPE_CHECKING: from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_executor.compilation.piecewise_context_manager import ( + get_forward_context, +) +from sglang.srt.utils import direct_register_custom_op + class AttentionType(Enum): """ @@ -105,12 +111,58 @@ class RadixAttention(nn.Module): else: k = k.view(-1, self.tp_k_head_num, self.v_head_dim) - return forward_batch.attn_backend.forward( - q, - k, - v, - self, - forward_batch, - save_kv_cache, - **kwargs, - ) + if forward_batch.forward_mode.is_extend() and get_forward_context() is not None: + output = torch.zeros_like(q) + torch.ops.sglang.unified_attention_with_output( + q, k, v, output, save_kv_cache, self.layer_id + ) + return output + else: + return forward_batch.attn_backend.forward( + q, + k, + v, + self, + forward_batch, + save_kv_cache, + **kwargs, + ) + + +def unified_attention_with_output( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + save_kv_cache: bool, + layer_id: int, +) -> None: + context = get_forward_context() + forward_batch = context.forward_batch + attention_layers = context.attention_layers + attention_layer = attention_layers[layer_id] + ret = forward_batch.attn_backend.forward( + query, key, value, attention_layer, forward_batch, save_kv_cache + ) + assert output.shape == ret.shape + output.copy_(ret) + return + + +def unified_attention_with_output_fake( + query: torch.Tensor, + key: 
torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + save_kv_cache: bool, + layer_id: int, +) -> None: + return + + +direct_register_custom_op( + op_name="unified_attention_with_output", + op_func=unified_attention_with_output, + mutates_args=["output"], + fake_impl=unified_attention_with_output_fake, +) diff --git a/python/sglang/srt/model_executor/compilation/backend.py b/python/sglang/srt/model_executor/compilation/backend.py new file mode 100644 index 000000000..031e40fd4 --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/backend.py @@ -0,0 +1,435 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/backend.py + + +import ast +import dataclasses +import logging +import os +import pprint +import time +from collections.abc import Sequence +from contextlib import contextmanager +from typing import Any, Callable, Optional + +import torch +import torch.fx as fx +from torch._dispatch.python import enable_python_dispatcher + +from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig +from sglang.srt.model_executor.compilation.compilation_counter import ( + compilation_counter, +) +from sglang.srt.model_executor.compilation.compiler_interface import InductorAdaptor +from sglang.srt.model_executor.compilation.cuda_piecewise_backend import ( + CUDAPiecewiseBackend, +) +from sglang.srt.model_executor.compilation.pass_manager import PostGradPassManager + +logger = logging.getLogger(__name__) + + +def make_compiler(): + return InductorAdaptor() + + +class CompilerManager: + def __init__( + self, + ): + self.cache = dict() + self.is_cache_updated = False + self.compiler = make_compiler() + + def compute_hash(self): + return self.compiler.compute_hash() + + def initialize_cache( + self, cache_dir: str, disable_cache: bool = False, prefix: str = "" + ): + self.disable_cache = disable_cache + self.cache_dir = cache_dir + self.cache_file_path = os.path.join(cache_dir, "sglang_compile_cache.py") + + if not disable_cache and os.path.exists(self.cache_file_path): + with open(self.cache_file_path) as f: + self.cache = ast.literal_eval(f.read()) + + self.compiler.initialize_cache( + cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix + ) + + def save_to_file(self): + if self.disable_cache or not self.is_cache_updated: + return + printer = pprint.PrettyPrinter(indent=4) + data = printer.pformat(self.cache) + with open(self.cache_file_path, "w") as f: + f.write(data) + + def load( + self, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None, + ) -> Optional[Callable]: + handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] + compiled_graph = self.compiler.load( + handle, graph, example_inputs, graph_index, runtime_shape + ) + if runtime_shape is None: + logger.debug( + "Directly load the %s-th graph for dynamic shape from %s via " + "handle %s", + graph_index, + self.compiler.name, + handle, + ) + else: + logger.debug( + "Directly load the %s-th graph for shape %s from %s via " "handle %s", + graph_index, + str(runtime_shape), + self.compiler.name, + handle, + ) + return compiled_graph + + def compile( + self, + graph: fx.GraphModule, + example_inputs, + inductor_config: dict[str, Any], + graph_index: int = 0, + num_graphs: int = 1, + runtime_shape: Optional[int] = None, + ) -> Any: + if graph_index == 0: + # before compiling the first graph, record the start time + global compilation_start_time + compilation_start_time = time.time() + + 
compilation_counter.num_backend_compilations += 1 + + compiled_graph = None + + # TODO(Yuwei): support cache loading + + # no compiler cached the graph, or the cache is disabled, + # we need to compile it + if isinstance(self.compiler, InductorAdaptor): + maybe_key = None + else: + maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}" + compiled_graph, handle = self.compiler.compile( + graph, example_inputs, inductor_config, runtime_shape, maybe_key + ) + + assert compiled_graph is not None, "Failed to compile the graph" + + # store the artifact in the cache + if handle is not None: + self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle + compilation_counter.num_cache_entries_updated += 1 + self.is_cache_updated = True + if graph_index == 0: + # adds some info logging for the first graph + if runtime_shape is None: + logger.info("Cache the graph for dynamic shape for later use") + else: + logger.info( + "Cache the graph of shape %s for later use", str(runtime_shape) + ) + if runtime_shape is None: + logger.debug( + "Store the %s-th graph for dynamic shape from %s via " "handle %s", + graph_index, + self.compiler.name, + handle, + ) + else: + logger.debug( + "Store the %s-th graph for shape %s from %s via handle %s", + graph_index, + str(runtime_shape), + self.compiler.name, + handle, + ) + + # after compiling the last graph, record the end time + if graph_index == num_graphs - 1: + now = time.time() + elapsed = now - compilation_start_time + if runtime_shape is None: + logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed) + else: + logger.info( + "Compiling a graph for shape %s takes %.2f s", + runtime_shape, + elapsed, + ) + + return compiled_graph + + +@dataclasses.dataclass +class SplitItem: + submod_name: str + graph_id: int + is_splitting_graph: bool + graph: fx.GraphModule + + +def split_graph( + graph: fx.GraphModule, ops: list[str] +) -> tuple[fx.GraphModule, list[SplitItem]]: + # split graph by ops + subgraph_id = 0 + node_to_subgraph_id = {} + split_op_graphs = [] + for node in graph.graph.nodes: + if node.op in ("output", "placeholder"): + continue + if node.op == "call_function" and str(node.target) in ops: + subgraph_id += 1 + node_to_subgraph_id[node] = subgraph_id + split_op_graphs.append(subgraph_id) + subgraph_id += 1 + else: + node_to_subgraph_id[node] = subgraph_id + + # `keep_original_order` is important! + # otherwise pytorch might reorder the nodes and + # the semantics of the graph will change when we + # have mutations in the graph + split_gm = torch.fx.passes.split_module.split_module( + graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True + ) + + outputs = [] + + names = [name for (name, module) in split_gm.named_modules()] + + for name in names: + if "." 
in name or name == "": + # recursive child module or the root module + continue + + module = getattr(split_gm, name) + + graph_id = int(name.replace("submod_", "")) + outputs.append(SplitItem(name, graph_id, (graph_id in split_op_graphs), module)) + + # sort by intetger graph_id, rather than string name + outputs.sort(key=lambda x: x.graph_id) + + return split_gm, outputs + + +# we share the global graph pool among all the backends +global_graph_pool = None + +compilation_start_time = 0.0 + + +class PiecewiseCompileInterpreter(torch.fx.Interpreter): + def __init__( + self, + module: torch.fx.GraphModule, + compile_submod_names: list[str], + inductor_config: dict[str, Any], + graph_pool, + compile_config: CompilationConfig, + sglang_backend: "SGLangBackend", + ): + super().__init__(module) + from torch._guards import detect_fake_mode + + self.fake_mode = detect_fake_mode() + self.compile_submod_names = compile_submod_names + self.graph_pool = graph_pool + self.sglang_backend = sglang_backend + # When True, it annoyingly dumps the torch.fx.Graph on errors. + self.extra_traceback = False + self.inductor_config = inductor_config + self.compile_config = compile_config + + def run(self, *args): + fake_args = [ + self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t + for t in args + ] + with self.fake_mode, enable_python_dispatcher(): + return super().run(*fake_args) + + def call_module( + self, + target: torch.fx.node.Target, + args: tuple[torch.fx.node.Argument, ...], + kwargs: dict[str, Any], + ) -> Any: + assert isinstance(target, str) + output = super().call_module(target, args, kwargs) + + if target in self.compile_submod_names: + index = self.compile_submod_names.index(target) + submod = self.fetch_attr(target) + sym_shape_indices = [ + i for i, x in enumerate(args) if isinstance(x, torch.SymInt) + ] + global compilation_start_time + compiled_graph_for_dynamic_shape = ( + self.sglang_backend.compiler_manager.compile( + submod, + args, + self.inductor_config, + graph_index=index, + num_graphs=len(self.compile_submod_names), + runtime_shape=None, + ) + ) + + self.module.__dict__[target] = CUDAPiecewiseBackend( + submod, + self.compile_config, + self.inductor_config, + self.graph_pool, + index, + len(self.compile_submod_names), + sym_shape_indices, + compiled_graph_for_dynamic_shape, + self.sglang_backend, + ) + + compilation_counter.num_piecewise_capturable_graphs_seen += 1 + + return output + + +model_tag: str = "backbone" + + +@contextmanager +def set_model_tag(tag: str): + """Context manager to set the model tag.""" + global model_tag + assert ( + tag != model_tag + ), f"Model tag {tag} is the same as the current tag {model_tag}." 
+    old_tag = model_tag
+    model_tag = tag
+    try:
+        yield
+    finally:
+        model_tag = old_tag
+
+
+class SGLangBackend:
+
+    graph_pool: Any
+    _called: bool = False
+    # the graph we compiled
+    graph: fx.GraphModule
+    # the stitching graph module for all the piecewise graphs
+    split_gm: fx.GraphModule
+    piecewise_graphs: list[SplitItem]
+    returned_callable: Callable
+    # Inductor passes to run on the graph pre-defunctionalization
+    post_grad_passes: Sequence[Callable]
+    sym_tensor_indices: list[int]
+    input_buffers: list[torch.Tensor]
+    compiler_manager: CompilerManager
+
+    def __init__(
+        self,
+        config: CompilationConfig,
+        graph_pool: Any,
+    ):
+        assert graph_pool is not None
+        self.graph_pool = graph_pool
+
+        self.post_grad_pass_manager = PostGradPassManager()
+        self.sym_tensor_indices = []
+        self.input_buffers = []
+
+        self.compiler_manager = CompilerManager()
+        self.inductor_config = {
+            "enable_auto_functionalized_v2": False,
+        }
+        self.compile_config = config
+
+    def configure_post_pass(self):
+        self.post_grad_pass_manager.configure()
+        self.inductor_config["post_grad_custom_post_pass"] = self.post_grad_pass_manager
+
+    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
+        base_cache_dir = os.path.expanduser(
+            os.getenv("SGLANG_CACHE_DIR", "~/.cache/sglang/")
+        )
+
+        cache_hash = self.compiler_manager.compute_hash()
+        cache_dir = os.path.join(
+            base_cache_dir,
+            "torch_compile_cache",
+            cache_hash,
+        )
+
+        os.makedirs(cache_dir, exist_ok=True)
+        rank = 0
+        dp_rank = 0
+        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", model_tag)
+        os.makedirs(local_cache_dir, exist_ok=True)
+        self.compiler_manager.initialize_cache(
+            local_cache_dir, disable_cache=False, prefix=""
+        )
+        compilation_counter.num_graphs_seen += 1
+
+        assert not self._called, "SGLangBackend can only be called once"
+
+        self.graph = graph
+        self.configure_post_pass()
+
+        self.split_gm, self.piecewise_graphs = split_graph(
+            graph, ["sglang.unified_attention_with_output"]
+        )
+
+        from torch._dynamo.utils import lazy_format_graph_code
+
+        # depyf will hook lazy_format_graph_code and dump the graph
+        # for debugging, no need to print the graph here
+        lazy_format_graph_code("before split", self.graph)
+        lazy_format_graph_code("after split", self.split_gm)
+
+        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
+
+        submod_names_to_compile = [
+            item.submod_name
+            for item in self.piecewise_graphs
+            if not item.is_splitting_graph
+        ]
+
+        PiecewiseCompileInterpreter(
+            self.split_gm,
+            submod_names_to_compile,
+            self.inductor_config,
+            self.graph_pool,
+            self.compile_config,
+            self,
+        ).run(*example_inputs)
+
+        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
+        if not os.path.exists(graph_path):
+            # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
+            # use `print_readable` because it can include submodules
+            src = (
+                "from __future__ import annotations\nimport torch\n"
+                + self.split_gm.print_readable(print_output=False)
+            )
+            src = src.replace("<lambda>", "GraphModule")
+            with open(graph_path, "w") as f:
+                f.write(src)
+
+            logger.debug("Computation graph saved to %s", graph_path)
+
+        self._called = True
+        return self.split_gm
diff --git a/python/sglang/srt/model_executor/compilation/compilation_config.py b/python/sglang/srt/model_executor/compilation/compilation_config.py
new file mode 100644
index 000000000..7a8ef6436
--- /dev/null
+++ 
b/python/sglang/srt/model_executor/compilation/compilation_config.py @@ -0,0 +1,19 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_config.py + +from typing import List + + +# TODO(Yuwei): support better compile config support +class CompilationConfig: + def __init__(self, capture_sizes: List[int]): + self.traced_files = set() + self.capture_sizes = capture_sizes + + def add_traced_file(self, file_path: str): + self.traced_files.add(file_path) + + def get_traced_files(self): + return self.traced_files + + def get_capture_sizes(self): + return self.capture_sizes diff --git a/python/sglang/srt/model_executor/compilation/compilation_counter.py b/python/sglang/srt/model_executor/compilation/compilation_counter.py new file mode 100644 index 000000000..e973f8f2f --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/compilation_counter.py @@ -0,0 +1,47 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_counter.py + +import copy +import dataclasses +from contextlib import contextmanager + + +@dataclasses.dataclass +class CompilationCounter: + num_models_seen: int = 0 + num_graphs_seen: int = 0 + # including the splitting ops + num_piecewise_graphs_seen: int = 0 + # not including the splitting ops + num_piecewise_capturable_graphs_seen: int = 0 + num_backend_compilations: int = 0 + # Number of gpu_model_runner attempts to trigger CUDAGraphs capture + num_gpu_runner_capture_triggers: int = 0 + # Number of CUDAGraphs captured + num_cudagraph_captured: int = 0 + # InductorAdapter.compile calls + num_inductor_compiles: int = 0 + # EagerAdapter.compile calls + num_eager_compiles: int = 0 + # The number of time vLLM's compiler cache entry was updated + num_cache_entries_updated: int = 0 + # The number of standalone_compile compiled artifacts saved + num_compiled_artifacts_saved: int = 0 + # Number of times a model was loaded with CompilationLevel.DYNAMO_AS_IS + dynamo_as_is_count: int = 0 + + def clone(self) -> "CompilationCounter": + return copy.deepcopy(self) + + @contextmanager + def expect(self, **kwargs): + old = self.clone() + yield + for k, v in kwargs.items(): + assert getattr(self, k) - getattr(old, k) == v, ( + f"{k} not as expected, before it is {getattr(old, k)}" + f", after it is {getattr(self, k)}, " + f"expected diff is {v}" + ) + + +compilation_counter = CompilationCounter() diff --git a/python/sglang/srt/model_executor/compilation/compile.py b/python/sglang/srt/model_executor/compilation/compile.py new file mode 100644 index 000000000..dee7f0169 --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/compile.py @@ -0,0 +1,210 @@ +import contextvars +import inspect +import logging +import os +import sys +import types +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch + +from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig + +logger = logging.getLogger(__name__) + +_COMPILE_ENABLED = contextvars.ContextVar("_COMPILE_ENABLED", default=False) + + +@contextmanager +def set_compiled(enabled: bool = True): + token = _COMPILE_ENABLED.set(enabled) + try: + yield + finally: + _COMPILE_ENABLED.reset(token) + + +@dataclass +class IntermediateTensors: + """For all pipeline stages except the last, we need to return the hidden + states and residuals to be sent to the next stage. This data structure + contains the hidden states and residuals for a request. 
+ + Each stage also needs to handle its own finished_sending and + finished_recving in case of kv transfer. + """ + + tensors: dict[str, torch.Tensor] + # [req_ids] + finished_sending: Optional[set[str]] = None + finished_recving: Optional[set[str]] = None + + def __init__(self, tensors): + # manually define this function, so that + # Dynamo knows `IntermediateTensors()` comes from this file. + # Otherwise, dataclass will generate this function by evaluating + # a string, and we will lose the information about the source file. + self.tensors = tensors + + def __getitem__(self, key: Union[str, slice]): + if isinstance(key, str): + return self.tensors[key] + elif isinstance(key, slice): + return self.__class__({k: v[key] for k, v in self.tensors.items()}) + + def __setitem__(self, key: str, value: torch.Tensor): + self.tensors[key] = value + + def items(self): + return self.tensors.items() + + def __len__(self): + return len(self.tensors) + + def __eq__(self, other: object): + return isinstance(other, self.__class__) and self + + def __repr__(self) -> str: + return f"IntermediateTensors(tensors={self.tensors})" + + +def _normalize_dims(dims, ndim: int): + dims = [dims] if isinstance(dims, int) else list(dims) + return [d if d >= 0 else ndim + d for d in dims] + + +class _MaybeIntermediateTensors: + """Duck-typed check to support your IntermediateTensors without importing.""" + + def __init__(self, obj): + self.is_intermediate = hasattr(obj, "tensors") and isinstance( + getattr(obj, "tensors"), dict + ) + self.obj = obj + + +def _mark_dynamic_on_value(val, dims): + if isinstance(val, torch.Tensor): + torch._dynamo.mark_dynamic(val, _normalize_dims(dims, val.ndim)) + else: + mit = _MaybeIntermediateTensors(val) + if mit.is_intermediate: + for t in mit.obj.tensors.values(): + torch._dynamo.mark_dynamic(t, _normalize_dims(dims, t.ndim)) + # else: ignore (None or non-tensor) + + +def _infer_dynamic_arg_dims_from_annotations(forward_fn): + sig = inspect.signature(forward_fn) + dyn = {} + for name, p in sig.parameters.items(): + ann = p.annotation + # Accept torch.Tensor / Optional[torch.Tensor] / your IntermediateTensors types by name + if ( + ann is torch.Tensor + or getattr(getattr(ann, "__args__", [None])[0], "__name__", "") == "Tensor" + ): + dyn[name] = 0 + elif getattr(ann, "__name__", "") in ("IntermediateTensors",) or any( + getattr(a, "__name__", "") == "IntermediateTensors" + for a in getattr(ann, "__args__", []) + ): + dyn[name] = 0 + if not dyn: + raise ValueError("No dynamic dims inferred; pass dynamic_arg_dims explicitly.") + return dyn + + +def install_torch_compiled( + module: torch.nn.Module, + *, + dynamic_arg_dims: dict[str, Union[int, list[int]]] | None = None, + backend_factory: Optional[Callable[[torch.fx.GraphModule, list], Callable]] = None, + compile_config: CompilationConfig = None, + fullgraph: bool = True, + graph_pool: Any = None, +): + unbound_fwd = module.__class__.forward + if not callable(unbound_fwd): + raise TypeError("module.__class__.forward must be callable") + original_code = unbound_fwd.__code__ + + dyn_map = dynamic_arg_dims or _infer_dynamic_arg_dims_from_annotations(unbound_fwd) + + if backend_factory is None: + from sglang.srt.model_executor.compilation.backend import SGLangBackend + + backend_factory = lambda gm, ex: SGLangBackend(compile_config, graph_pool)( + gm, ex + ) + + compiled_codes: list[type(original_code)] = [] + state = {"compiled": False, "compiled_callable": None} + + def bytecode_hook(old_code, new_code): + if old_code is not 
original_code: + return + frame = sys._getframe() + while frame and frame.f_back: + frame = frame.f_back + if ( + frame.f_code.co_name == "_compile" + and os.path.basename(frame.f_code.co_filename) == "convert_frame.py" + ): + break + try: + dynamo_frame = frame.f_locals["frame"] + except Exception: + return + if dynamo_frame.f_code is not old_code: + return + if dynamo_frame.f_locals.get("self") is not module: + return + compiled_codes.append(new_code) + + torch._dynamo.convert_frame.register_bytecode_hook(bytecode_hook) + + def _ensure_compiled(self, *args, **kwargs): + """Compile on first use (with flag ON).""" + if state["compiled"]: + return + # Mark dynamic dims only when we are about to compile + sig = inspect.signature(unbound_fwd) + ba = sig.bind(self, *args, **kwargs) + ba.apply_defaults() + for name, dims in (dyn_map or {}).items(): + if name in ba.arguments: + val = ba.arguments[name] + if val is not None: + _mark_dynamic_on_value(val, dims) + + # Avoid cross-instance cache reuse + torch._dynamo.eval_frame.remove_from_cache(unbound_fwd.__code__) + + bound = types.MethodType(unbound_fwd, self) + compiled_callable = torch.compile( + bound, fullgraph=fullgraph, backend=backend_factory + ) + + # Trigger Dynamo so bytecode hook can capture + compiled_callable(*args, **kwargs) + + state["compiled"] = True + state["compiled_callable"] = compiled_callable + + def trampoline(self, *args, **kwargs): + use_compiled = _COMPILE_ENABLED.get() + if use_compiled: + if not state["compiled"]: + _ensure_compiled(self, *args, **kwargs) + + compiled_callable = state["compiled_callable"] + return compiled_callable(*args, **kwargs) + else: + # Explicitly run the original uncompiled forward + return unbound_fwd(self, *args, **kwargs) + + module.forward = types.MethodType(trampoline, module) + return module diff --git a/python/sglang/srt/model_executor/compilation/compiler_interface.py b/python/sglang/srt/model_executor/compilation/compiler_interface.py new file mode 100644 index 000000000..016703022 --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/compiler_interface.py @@ -0,0 +1,479 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compiler_interface.py + +import contextlib +import copy +import hashlib +import os +from contextlib import ExitStack +from typing import Any, Callable, Optional +from unittest.mock import patch + +import torch +import torch._inductor.compile_fx +import torch.fx as fx + +from sglang.srt.model_executor.compilation.compilation_counter import ( + compilation_counter, +) +from sglang.srt.model_executor.compilation.inductor_pass import pass_context + + +class CompilerInterface: + """ + The interface for a compiler that can be used by vLLM. + """ + + # The name of the compiler, e.g. inductor. + # This is a class-level attribute. + name: str + + def initialize_cache( + self, cache_dir: str, disable_cache: bool = False, prefix: str = "" + ): + """ + when the vLLM process uses `cache_dir` as the cache directory, + the compiler should initialize itself with the cache directory, + e.g. by re-directing its own cache directory to a sub-directory. + + prefix can be used in combination with cache_dir to figure out the base + cache directory, e.g. there're multiple parts of model being compiled, + but we want to share the same cache directory for all of them. + + e.g. 
+ cache_dir = "/path/to/dir/backbone", prefix = "backbone" + cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head" + """ + pass + + def compute_hash(self) -> str: + """ + Gather all the relevant information from the vLLM config, + to compute a hash so that we can cache the compiled model. + + See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash] + to check what information + is already considered by default. This function should only + consider the information that is specific to the compiler. + """ + return "" + + def compile( + self, + graph: fx.GraphModule, + example_inputs: list[Any], + compiler_config: dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, + ) -> tuple[Optional[Callable], Optional[Any]]: + """ + Compile the graph with the given example inputs and compiler config, + with a runtime shape. If the `runtime_shape` is None, it means + the `example_inputs` have a dynamic shape. Otherwise, the + `runtime_shape` specifies the shape of the inputs. Right now we only + support one variable shape for all inputs, which is the batchsize + (number of tokens) during inference. + + Dynamo will make sure `graph(*example_inputs)` is valid. + + The function should return a compiled callable function, as well as + a handle that can be used to directly load the compiled function. + + The handle should be a plain Python object, preferably a string or a + file path for readability. + + If the compiler doesn't support caching, it should return None for the + handle. If the compiler fails to compile the graph, it should return + None for the compiled function as well. + + `key` is required for StandaloneInductorAdapter, it specifies where to + save the compiled artifact. The compiled artifact gets saved to + `cache_dir/key`. + """ + return None, None + + def load( + self, + handle: Any, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None, + ) -> Callable: + """ + Load the compiled function from the handle. + Raises an error if the handle is invalid. + + The handle is the second return value of the `compile` function. + """ + raise NotImplementedError("caching is not supported") + + +def get_inductor_factors() -> list[Any]: + factors: list[Any] = [] + # summarize system state + from torch._inductor.codecache import CacheBase + + system_factors = CacheBase.get_system() + factors.append(system_factors) + + # summarize pytorch state + from torch._inductor.codecache import torch_key + + torch_factors = torch_key() + factors.append(torch_factors) + return factors + + +class AlwaysHitShapeEnv: + """ + Why do we need this class: + + For normal `torch.compile` usage, every compilation will have + one Dynamo bytecode compilation and one Inductor compilation. + The Inductor compilation happens under the context of the + Dynamo bytecode compilation, and that context is used to + determine the dynamic shape information, etc. + + For our use case, we only run Dynamo bytecode compilation once, + and run Inductor compilation multiple times with different shapes + plus a general shape. The compilation for specific shapes happens + outside of the context of the Dynamo bytecode compilation. At that + time, we don't have shape environment to provide to Inductor, and + it will fail the Inductor code cache lookup. + + By providing a dummy shape environment that always hits, we can + make the Inductor code cache lookup always hit, and we can + compile the graph for different shapes as needed. 
+ + The following dummy methods are obtained by trial-and-error + until it works. + """ + + def __init__(self) -> None: + self.guards: list[Any] = [] + + def evaluate_guards_expression(self, *args, **kwargs): + return True + + def get_pruned_guards(self, *args, **kwargs): + return [] + + def produce_guards_expression(self, *args, **kwargs): + return "" + + +class InductorAdaptor(CompilerInterface): + """ + The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7. + """ + + name = "inductor" + + def compute_hash(self) -> str: + factors = get_inductor_factors() + hash_str = hashlib.md5( + str(factors).encode(), usedforsecurity=False + ).hexdigest()[:10] + return hash_str + + def initialize_cache( + self, cache_dir: str, disable_cache: bool = False, prefix: str = "" + ): + self.cache_dir = cache_dir + self.prefix = prefix + self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir + if disable_cache: + return + # redirect the cache directory to a sub-directory + # set flags so that Inductor and Triton store their cache + # in the cache_dir, then users only need to copy the cache_dir + # to another machine to reuse the cache. + inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache") + os.makedirs(inductor_cache, exist_ok=True) + os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache + triton_cache = os.path.join(self.base_cache_dir, "triton_cache") + os.makedirs(triton_cache, exist_ok=True) + os.environ["TRITON_CACHE_DIR"] = triton_cache + + def compile( + self, + graph: fx.GraphModule, + example_inputs: list[Any], + compiler_config: dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, + ) -> tuple[Optional[Callable], Optional[Any]]: + compilation_counter.num_inductor_compiles += 1 + from torch._inductor.compile_fx import compile_fx + + current_config = {} + if compiler_config is not None: + current_config.update(compiler_config) + + # disable remote cache + current_config["fx_graph_cache"] = True + current_config["fx_graph_remote_cache"] = False + + set_inductor_config(current_config, runtime_shape) + + # inductor can inplace modify the graph, so we need to copy it + # see https://github.com/pytorch/pytorch/issues/138980 + graph = copy.deepcopy(graph) + + # it's the first time we compile this graph + # the assumption is that we don't have nested Inductor compilation. + # compiled_fx_graph_hash will only be called once, and we can hook + # it to get the hash of the compiled graph directly. 
+ + hash_str, file_path = None, None + from torch._inductor.codecache import FxGraphCache, compiled_fx_graph_hash + + if torch.__version__.startswith("2.5"): + original_load = FxGraphCache.load + original_load_name = "torch._inductor.codecache.FxGraphCache.load" + + def hijack_load(*args, **kwargs): + inductor_compiled_graph = original_load(*args, **kwargs) + nonlocal file_path + compiled_fn = inductor_compiled_graph.current_callable + file_path = compiled_fn.__code__.co_filename # noqa + if not file_path.startswith(self.base_cache_dir): + # hooked in the align_inputs_from_check_idxs function + # in torch/_inductor/utils.py + for cell in compiled_fn.__closure__: + if not callable(cell.cell_contents): + continue + if cell.cell_contents.__code__.co_filename.startswith( + self.base_cache_dir + ): + # this is the real file path compiled from Inductor + file_path = cell.cell_contents.__code__.co_filename + break + return inductor_compiled_graph + + hijacked_compile_fx_inner = ( + torch._inductor.compile_fx.compile_fx_inner + ) # noqa + elif torch.__version__ >= "2.6": + # function renamed in 2.6 + original_load_name = None + + def hijacked_compile_fx_inner(*args, **kwargs): + output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs) + nonlocal hash_str + inductor_compiled_graph = output + if inductor_compiled_graph is not None: + nonlocal file_path + compiled_fn = inductor_compiled_graph.current_callable + file_path = compiled_fn.__code__.co_filename # noqa + if not file_path.startswith(self.base_cache_dir): + # hooked in the align_inputs_from_check_idxs function + # in torch/_inductor/utils.py + for cell in compiled_fn.__closure__: + if not callable(cell.cell_contents): + continue + code = cell.cell_contents.__code__ + if code.co_filename.startswith(self.base_cache_dir): + # this is the real file path + # compiled from Inductor + file_path = code.co_filename + break + hash_str = inductor_compiled_graph._fx_graph_cache_key + return output + + def hijack_compiled_fx_graph_hash(*args, **kwargs): + out = compiled_fx_graph_hash(*args, **kwargs) + nonlocal hash_str + hash_str = out[0] + return out + + def _check_can_cache(*args, **kwargs): + # no error means it can be cached. + # Inductor refuses to cache the graph outside of Dynamo + # tracing context, and also disables caching for graphs + # with high-order ops. + # For vLLM, in either case, we want to cache the graph. 
+ # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa + return + + def _get_shape_env() -> AlwaysHitShapeEnv: + return AlwaysHitShapeEnv() + + with ExitStack() as stack: + # hijack to get the compiled graph itself + if original_load_name is not None: + stack.enter_context(patch(original_load_name, hijack_load)) + + # for hijacking the hash of the compiled graph + stack.enter_context( + patch( + "torch._inductor.codecache.compiled_fx_graph_hash", + hijack_compiled_fx_graph_hash, + ) + ) + + # for providing a dummy shape environment + stack.enter_context( + patch( + "torch._inductor.codecache.FxGraphCache._get_shape_env", + _get_shape_env, + ) + ) + + from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache + + # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache + if hasattr(AOTAutogradCache, "_get_shape_env"): + stack.enter_context( + patch( + "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env", + _get_shape_env, + ) + ) + + # for forcing the graph to be cached + stack.enter_context( + patch( + "torch._inductor.codecache.FxGraphCache._check_can_cache", + _check_can_cache, + ) + ) + + # Dynamo metrics context, see method for more details. + stack.enter_context(self.metrics_context()) + + # Disable remote caching. When these are on, on remote cache-hit, + # the monkey-patched functions never actually get called. + # vLLM today assumes and requires the monkey-patched functions to + # get hit. + # TODO(zou3519): we're going to replace this all with + # standalone_compile sometime. + + stack.enter_context( + torch._inductor.config.patch(fx_graph_remote_cache=False) + ) + # InductorAdaptor (unfortunately) requires AOTAutogradCache + # to be turned off to run. It will fail to acquire the hash_str + # and error if not. + # StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem. + stack.enter_context( + torch._functorch.config.patch(enable_autograd_cache=False) + ) + stack.enter_context( + torch._functorch.config.patch(enable_remote_autograd_cache=False) + ) + + with pass_context(runtime_shape): + compiled_graph = compile_fx( + graph, + example_inputs, + inner_compile=hijacked_compile_fx_inner, + config_patches=current_config, + ) + return compiled_graph, (hash_str, file_path) + + def load( + self, + handle: Any, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None, + ) -> Callable: + assert isinstance(handle, tuple) + assert isinstance(handle[0], str) + assert isinstance(handle[1], str) + hash_str = handle[0] + + from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache + from torch._inductor.codecache import FxGraphCache + + with ExitStack() as exit_stack: + exit_stack.enter_context( + patch( + "torch._inductor.codecache.FxGraphCache._get_shape_env", + lambda *args, **kwargs: AlwaysHitShapeEnv(), + ) + ) + # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache + if hasattr(AOTAutogradCache, "_get_shape_env"): + exit_stack.enter_context( + patch( + "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env", + lambda *args, **kwargs: AlwaysHitShapeEnv(), + ) + ) + + # Dynamo metrics context, see method for more details. 
+ exit_stack.enter_context(self.metrics_context()) + + if torch.__version__.startswith("2.5"): + inductor_compiled_graph = FxGraphCache._lookup_graph( + hash_str, example_inputs, True, False + ) + assert inductor_compiled_graph is not None, ( + "Inductor cache lookup failed. Please remove" + f"the cache directory and try again." # noqa + ) + elif torch.__version__ >= "2.6": + from torch._inductor.output_code import CompiledFxGraphConstantsWithGm + + constants = CompiledFxGraphConstantsWithGm(graph) + inductor_compiled_graph, _ = FxGraphCache._lookup_graph( + hash_str, example_inputs, True, None, constants + ) + assert inductor_compiled_graph is not None, ( + "Inductor cache lookup failed. Please remove" + f"the cache directory and try again." # noqa + ) + + # Inductor calling convention (function signature): + # f(list) -> tuple + # Dynamo calling convention (function signature): + # f(*args) -> Any + + # need to know if the graph returns a tuple + from torch._inductor.compile_fx import graph_returns_tuple + + returns_tuple = graph_returns_tuple(graph) + + # this is the callable we return to Dynamo to run + def compiled_graph(*args): + # convert args to list + list_args = list(args) + graph_output = inductor_compiled_graph(list_args) + # unpack the tuple if needed + if returns_tuple: + return graph_output + else: + return graph_output[0] + + return compiled_graph + + def metrics_context(self) -> contextlib.AbstractContextManager: + """ + This method returns the Dynamo metrics context (if it exists, + otherwise a null context). It is used by various compile components. + Present in torch>=2.6, it's used inside FxGraphCache in + torch==2.6 (but not after). It might also be used in various other + torch.compile internal functions. + + Because it is re-entrant, we always set it (even if entering via Dynamo + and the context was already entered). We might want to revisit if it + should be set at a different level of compilation. + + This is likely a bug in PyTorch: public APIs should not rely on + manually setting up internal contexts. But we also rely on non-public + APIs which might not provide these guarantees. + """ + import torch._dynamo.utils + + return torch._dynamo.utils.get_metrics_context() + + +def set_inductor_config(config, runtime_shape): + if isinstance(runtime_shape, int): + # for a specific batchsize, tuning triton kernel parameters + # can be beneficial + config["max_autotune"] = True + config["coordinate_descent_tuning"] = True diff --git a/python/sglang/srt/model_executor/compilation/cuda_piecewise_backend.py b/python/sglang/srt/model_executor/compilation/cuda_piecewise_backend.py new file mode 100644 index 000000000..22f35b3bc --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/cuda_piecewise_backend.py @@ -0,0 +1,230 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/cuda_piecewise_backend.py + +import dataclasses +import logging +from contextlib import ExitStack +from typing import Any, Callable, Optional, Union +from unittest.mock import patch + +import torch +import torch.fx as fx + +import sglang.srt.model_executor.compilation.weak_ref_tensor_jit +from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig +from sglang.srt.model_executor.compilation.compilation_counter import ( + compilation_counter, +) + +logger = logging.getLogger(__name__) + + +def weak_ref_tensor(tensor: Any) -> Any: + """ + Create a weak reference to a tensor. 
+ The new tensor will share the same data as the original tensor, + but will not keep the original tensor alive. + """ + if isinstance(tensor, torch.Tensor): + # TODO(yuwei): introduce weak_ref_tensor from sgl_kernel + return torch.ops.jit_weak_ref_tensor.weak_ref_tensor(tensor) + return tensor + + +def weak_ref_tensors( + tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]] +) -> Union[torch.Tensor, list[Any], tuple[Any], Any]: + """ + Convenience function to create weak references to tensors, + for single tensor, list of tensors or tuple of tensors. + """ + if isinstance(tensors, torch.Tensor): + return weak_ref_tensor(tensors) + if isinstance(tensors, list): + return [weak_ref_tensor(t) for t in tensors] + if isinstance(tensors, tuple): + return tuple(weak_ref_tensor(t) for t in tensors) + raise ValueError("Invalid type for tensors") + + +@dataclasses.dataclass +class ConcreteSizeEntry: + runtime_shape: int + need_to_compile: bool # the size is in compile_sizes + use_cudagraph: bool # the size is in cudagraph_capture_sizes + + compiled: bool = False + runnable: Callable = None # type: ignore + num_finished_warmup: int = 0 + cudagraph: Optional[torch.cuda.CUDAGraph] = None + output: Optional[Any] = None + + # for cudagraph debugging, track the input addresses + # during capture, and check if they are the same during replay + input_addresses: Optional[list[int]] = None + + +class CUDAPiecewiseBackend: + + def __init__( + self, + graph: fx.GraphModule, + compile_config: CompilationConfig, + inductor_config: dict[str, Any], + graph_pool: Any, + piecewise_compile_index: int, + total_piecewise_compiles: int, + sym_shape_indices: list[int], + compiled_graph_for_general_shape: Callable, + sglang_backend, + ): + """ + The backend for piecewise compilation. + It mainly handles the compilation and cudagraph capturing. + + We will compile `self.graph` once for the general shape, + and then compile for different shapes specified in + `compilation_config.compile_sizes`. + + Independently, we will capture cudagraph for different shapes. + + If a shape needs both compilation and cudagraph, we will + compile it first, and then capture cudagraph. 
+ """ + self.graph = graph + self.inductor_config = inductor_config + self.graph_pool = graph_pool + self.piecewise_compile_index = piecewise_compile_index + self.total_piecewise_compiles = total_piecewise_compiles + self.sglang_backend = sglang_backend + + self.is_first_graph = piecewise_compile_index == 0 + self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1 + + self.compile_sizes: set[int] = set([]) + self.compile_config = compile_config + self.cudagraph_capture_sizes: set[int] = set(compile_config.get_capture_sizes()) + + self.first_run_finished = False + + self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa + + self.sym_shape_indices = sym_shape_indices + + self.is_debugging_mode = True + + # the entries for different shapes that we need to either + # compile or capture cudagraph + self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} + + # to_be_compiled_sizes tracks the remaining sizes to compile, + # and updates during the compilation process, so we need to copy it + self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() + for shape in self.compile_sizes.union(self.cudagraph_capture_sizes): + self.concrete_size_entries[shape] = ConcreteSizeEntry( + runtime_shape=shape, + need_to_compile=shape in self.compile_sizes, + use_cudagraph=shape in self.cudagraph_capture_sizes, + ) + + def check_for_ending_compilation(self): + if self.is_last_graph and not self.to_be_compiled_sizes: + # no specific sizes to compile + # save the hash of the inductor graph for the next run + self.sglang_backend.compiler_manager.save_to_file() + + def __call__(self, *args) -> Any: + if not self.first_run_finished: + self.first_run_finished = True + self.check_for_ending_compilation() + return self.compiled_graph_for_general_shape(*args) + runtime_shape = args[self.sym_shape_indices[0]] + if runtime_shape not in self.concrete_size_entries: + # we don't need to do anything for this shape + return self.compiled_graph_for_general_shape(*args) + + entry = self.concrete_size_entries[runtime_shape] + + if entry.runnable is None: + entry.runnable = self.compiled_graph_for_general_shape + + if entry.need_to_compile and not entry.compiled: + entry.compiled = True + self.to_be_compiled_sizes.remove(runtime_shape) + # args are real arguments + entry.runnable = self.sglang_backend.compiler_manager.compile( + self.graph, + args, + self.inductor_config, + graph_index=self.piecewise_compile_index, + num_graphs=self.total_piecewise_compiles, + runtime_shape=runtime_shape, + ) + + # finished compilations for all required shapes + if self.is_last_graph and not self.to_be_compiled_sizes: + self.check_for_ending_compilation() + + # Skip CUDA graphs if this entry doesn't use them OR + # if we're supposed to skip them globally + # skip_cuda_graphs = get_forward_context().skip_cuda_graphs + # if not entry.use_cudagraph or skip_cuda_graphs: + # return entry.runnable(*args) + + if entry.cudagraph is None: + if entry.num_finished_warmup < 1: # noqa + entry.num_finished_warmup += 1 + return entry.runnable(*args) + + input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + entry.input_addresses = input_addresses + cudagraph = torch.cuda.CUDAGraph() + + with ExitStack() as stack: + if not self.is_first_graph: + # during every model forward, we will capture + # many pieces of cudagraphs (roughly one per layer). + # running gc again and again across layers will + # make the cudagraph capture very slow. 
+ # therefore, we only run gc for the first graph, + # and disable gc for the rest of the graphs. + stack.enter_context(patch("gc.collect", lambda: None)) + stack.enter_context(patch("torch.cuda.empty_cache", lambda: None)) + + # mind-exploding: carefully manage the reference and memory. + with torch.cuda.graph(cudagraph, pool=self.graph_pool): + # `output` is managed by pytorch's cudagraph pool + output = entry.runnable(*args) + if self.is_last_graph: + # by converting it to weak ref, + # the original `output` will immediately be released + # to save memory. It is only safe to do this for + # the last graph, because the output of the last graph + # will not be used by any other cuda graph. + output = weak_ref_tensors(output) + + # here we always use weak ref for the output + # to save memory + entry.output = weak_ref_tensors(output) + entry.cudagraph = cudagraph + + compilation_counter.num_cudagraph_captured += 1 + + # important: we need to return the output, rather than + # the weak ref of the output, so that pytorch can correctly + # manage the memory during cuda graph capture + return output + + if self.is_debugging_mode: + # check if the input addresses are the same + new_input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + assert new_input_addresses == entry.input_addresses, ( + "Input addresses for cudagraphs are different during replay." + f" Expected {entry.input_addresses}, got {new_input_addresses}" + ) + + entry.cudagraph.replay() + return entry.output diff --git a/python/sglang/srt/model_executor/compilation/fix_functionalization.py b/python/sglang/srt/model_executor/compilation/fix_functionalization.py new file mode 100644 index 000000000..bd18173ae --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/fix_functionalization.py @@ -0,0 +1,134 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/fix_functionalization.py + +import logging +import operator +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch._higher_order_ops.auto_functionalize import auto_functionalized + +from sglang.srt.model_executor.compilation.fx_utils import is_func +from sglang.srt.model_executor.compilation.inductor_pass import SGLangInductorPass + +logger = logging.getLogger(__name__) + + +class FixFunctionalizationPass(SGLangInductorPass): + """ + This pass defunctionalizes certain nodes to avoid redundant tensor copies. + After this pass, DCE (dead-code elimination) should never be run, + as de-functionalized nodes may appear as dead code. + + To add new nodes to defunctionalize, add to the if-elif chain in __call__. + """ + + def __call__(self, graph: torch.fx.Graph): + self.begin() + self.dump_graph(graph, "before_fix_functionalization") + + self.nodes_to_remove: list[torch.fx.Node] = [] + count = 0 + for node in graph.nodes: + if not is_func(node, auto_functionalized): + continue # Avoid deep if-elif nesting + count += 1 + + self.dump_graph(graph, "before_fix_functionalization_cleanup") + + # Remove the nodes all at once + count_removed = len(self.nodes_to_remove) + for node in self.nodes_to_remove: + graph.erase_node(node) + + logger.debug( + "De-functionalized %s nodes, removed %s nodes", count, count_removed + ) + self.dump_graph(graph, "after_fix_functionalization") + self.end_and_log() + + def _remove(self, node_or_nodes: Union[torch.fx.Node, Iterable[torch.fx.Node]]): + """ + Stage a node (or nodes) for removal at the end of the pass. 
+ """ + if isinstance(node_or_nodes, torch.fx.Node): + self.nodes_to_remove.append(node_or_nodes) + else: + self.nodes_to_remove.extend(node_or_nodes) + + def defunctionalize( + self, + graph: torch.fx.Graph, + node: torch.fx.Node, + mutated_args: dict[int, Union[torch.fx.Node, str]], + args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None, + ): + """ + De-functionalize a node by replacing it with a call to the original. + It also replaces the getitem users with the mutated arguments. + See replace_users_with_mutated_args and insert_defunctionalized. + """ + self.replace_users_with_mutated_args(node, mutated_args) + self.insert_defunctionalized(graph, node, args=args) + self._remove(node) + + def replace_users_with_mutated_args( + self, node: torch.fx.Node, mutated_args: dict[int, Union[torch.fx.Node, str]] + ): + """ + Replace all getitem users of the auto-functionalized node with the + mutated arguments. + :param node: The auto-functionalized node + :param mutated_args: The mutated arguments, indexed by getitem index. + If the value of an arg is a string, `node.kwargs[arg]` is used. + """ + for idx, user in self.getitem_users(node).items(): + arg = mutated_args[idx] + arg = node.kwargs[arg] if isinstance(arg, str) else arg + user.replace_all_uses_with(arg) + self._remove(user) + + def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]: + """ + Returns the operator.getitem users of the auto-functionalized node, + indexed by the index they are getting. + """ + users = {} + for user in node.users: + if is_func(user, operator.getitem): + idx = user.args[1] + users[idx] = user + return users + + def insert_defunctionalized( + self, + graph: torch.fx.Graph, + node: torch.fx.Node, + args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None, + ): + """ + Insert a new defunctionalized node into the graph before node. + If one of the kwargs is 'out', provide args directly, + as node.kwargs cannot be used. + See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 + + :param graph: Graph to insert the defunctionalized node into + :param node: The auto-functionalized node to defunctionalize + :param args: If we cannot use kwargs, specify args directly. + If an arg is a string, `node.kwargs[arg]` is used. 
+ """ # noqa: E501 + assert is_func( + node, auto_functionalized + ), f"node must be auto-functionalized, is {node} instead" + + # Create a new call to the original function + with graph.inserting_before(node): + function = node.args[0] + if args is None: + graph.call_function(function, kwargs=node.kwargs) + else: + # Args passed as strings refer to items in node.kwargs + args = tuple( + node.kwargs[arg] if isinstance(arg, str) else arg for arg in args + ) + graph.call_function(function, args=args) diff --git a/python/sglang/srt/model_executor/compilation/fx_utils.py b/python/sglang/srt/model_executor/compilation/fx_utils.py new file mode 100644 index 000000000..b2e863e68 --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/fx_utils.py @@ -0,0 +1,83 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/fx_utils.py + +import operator +from collections.abc import Iterable, Iterator +from typing import Optional + +from torch import fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._ops import OpOverload + + +def is_func(node: fx.Node, target) -> bool: + return node.op == "call_function" and node.target == target + + +def is_auto_func(node: fx.Node, op: OpOverload) -> bool: + return is_func(node, auto_functionalized) and node.args[0] == op + + +# Returns the first specified node with the given op (if it exists) +def find_specified_fn_maybe( + nodes: Iterable[fx.Node], op: OpOverload +) -> Optional[fx.Node]: + for node in nodes: + if node.target == op: + return node + return None + + +# Returns the first specified node with the given op +def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node: + node = find_specified_fn_maybe(nodes, op) + assert node is not None, f"Could not find {op} in nodes {nodes}" + return node + + +# Returns the first auto_functionalized node with the given op (if it exists) +def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> Optional[fx.Node]: + for node in nodes: + if is_func(node, auto_functionalized) and node.args[0] == op: # noqa + return node + return None + + +# Returns the first auto_functionalized node with the given op +def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node: + node = find_auto_fn_maybe(nodes, op) + assert node is not None, f"Could not find {op} in nodes {nodes}" + return node + + +# Returns the getitem node that extracts the idx-th element from node +# (if it exists) +def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]: + for user in node.users: + if is_func(user, operator.getitem) and user.args[1] == idx: + return user + return None + + +# Returns the getitem node that extracts the idx-th element from node +def find_getitem(node: fx.Node, idx: int) -> fx.Node: + ret = find_getitem_maybe(node, idx) + assert ret is not None, f"Could not find getitem {idx} in node {node}" + return ret + + +# An auto-functionalization-aware utility for finding nodes with a specific op +def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]: + if not op._schema.is_mutable: + yield from graph.find_nodes(op="call_function", target=op) + + for n in graph.find_nodes(op="call_function", target=auto_functionalized): + if n.args[0] == op: + yield n + + +# Asserts that the node only has one user and returns it +# Even if a node has only 1 user, it might share storage with another node, +# which might need to be taken into account. 
+def get_only_user(node: fx.Node) -> fx.Node: + assert len(node.users) == 1 + return next(iter(node.users)) diff --git a/python/sglang/srt/model_executor/compilation/inductor_pass.py b/python/sglang/srt/model_executor/compilation/inductor_pass.py new file mode 100644 index 000000000..acbde65bf --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/inductor_pass.py @@ -0,0 +1,140 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/inductor_pass.py + +import hashlib +import inspect +import json +import logging +import time +import types +from contextlib import contextmanager +from typing import Any, Callable, Optional, Union + +import torch +from torch import fx +from torch._dynamo.utils import lazy_format_graph_code +from torch._inductor.custom_graph_pass import CustomGraphPass + +logger = logging.getLogger(__name__) + +_pass_context = None + + +class PassContext: + + def __init__(self, runtime_shape: Optional[int]): + self.runtime_shape = runtime_shape + + +def get_pass_context() -> PassContext: + """Get the current pass context.""" + assert _pass_context is not None + return _pass_context + + +@contextmanager +def pass_context(runtime_shape: Optional[int]): + """A context manager that stores the current pass context, + usually it is a list of sizes to specialize. + """ + global _pass_context + prev_context = _pass_context + _pass_context = PassContext(runtime_shape) + try: + yield + finally: + _pass_context = prev_context + + +class InductorPass(CustomGraphPass): + """ + A custom graph pass that uses a hash of its source as the UUID. + This is defined as a convenience and should work in most cases. + """ + + def uuid(self) -> Any: + """ + Provide a unique identifier for the pass, used in Inductor code cache. + This should depend on the pass implementation, so that changes to the + pass result in recompilation. + By default, the object source is hashed. + """ + return InductorPass.hash_source(self) + + @staticmethod + def hash_source(*srcs: Union[str, Any]): + """ + Utility method to hash the sources of functions or objects. + :param srcs: strings or objects to add to the hash. + Objects and functions have their source inspected. + :return: + """ + hasher = hashlib.sha256() + for src in srcs: + if isinstance(src, str): + src_str = src + elif isinstance(src, types.FunctionType): + src_str = inspect.getsource(src) + else: + src_str = inspect.getsource(src.__class__) + hasher.update(src_str.encode("utf-8")) + return hasher.hexdigest() + + @staticmethod + def hash_dict(dict_: dict[Any, Any]): + """ + Utility method to hash a dictionary, can alternatively be used for uuid. + :return: A sha256 hash of the json rep of the dictionary. + """ + encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") + return hashlib.sha256(encoded).hexdigest() + + def is_applicable_for_shape(self, shape: Optional[int]): + return True + + +class CallableInductorPass(InductorPass): + """ + This class is a wrapper for a callable that automatically provides an + implementation of the UUID. 
+ """ + + def __init__( + self, callable: Callable[[fx.Graph], None], uuid: Optional[Any] = None + ): + self.callable = callable + self._uuid = self.hash_source(callable) if uuid is None else uuid + + def __call__(self, graph: torch.fx.Graph): + self.callable(graph) + + def uuid(self) -> Any: + return self._uuid + + +class SGLangInductorPass(InductorPass): + + def __init__( + self, + ): + self.pass_name = self.__class__.__name__ + + def dump_graph(self, graph: torch.fx.Graph, stage: str): + lazy_format_graph_code(stage, graph.owning_module) + + def begin(self): + self._start_time = time.perf_counter_ns() + + def end_and_log(self): + self._end_time = time.perf_counter_ns() + duration_ms = float(self._end_time - self._start_time) / 1.0e6 + logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms) + + +class PrinterInductorPass(SGLangInductorPass): + + def __init__(self, name: str): + super().__init__() + self.name = name + + def __call__(self, graph: torch.fx.Graph): + self.dump_graph(graph, self.name) diff --git a/python/sglang/srt/model_executor/compilation/pass_manager.py b/python/sglang/srt/model_executor/compilation/pass_manager.py new file mode 100644 index 000000000..bc06a49ea --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/pass_manager.py @@ -0,0 +1,68 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/pass_manager.py + +import logging + +from torch import fx as fx + +from sglang.srt.model_executor.compilation.fix_functionalization import ( + FixFunctionalizationPass, +) +from sglang.srt.model_executor.compilation.inductor_pass import ( + CustomGraphPass, + InductorPass, + SGLangInductorPass, + get_pass_context, +) + +logger = logging.getLogger(__name__) + + +class PostGradPassManager(CustomGraphPass): + """ + The pass manager for post-grad passes. + It handles configuration, adding custom passes, and running passes. + It supports uuid for the Inductor code cache. That includes torch<2.6 + support using pickling (in .inductor_pass.CustomGraphPass). + + The order of the post-grad post-passes is: + 1. passes (constructor parameter) + 2. default passes (NoopEliminationPass, FusionPass) + 3. config["post_grad_custom_post_pass"] (if it exists) + 4. fix_functionalization + This way, all passes operate on a functionalized graph. + """ + + def __init__(self): + self.passes: list[SGLangInductorPass] = [] + + def __call__(self, graph: fx.Graph): + shape = get_pass_context().runtime_shape + for pass_ in self.passes: + if pass_.is_applicable_for_shape(shape): + pass_(graph) + + # always run fix_functionalization last + self.fix_functionalization(graph) + + def configure( + self, + ): + self.pass_config = dict() + self.fix_functionalization = FixFunctionalizationPass() + + def add(self, pass_: InductorPass): + assert isinstance(pass_, InductorPass) + self.passes.append(pass_) + + def uuid(self): + """ + The PostGradPassManager is set as a custom pass in the Inductor and + affects compilation caching. Its uuid depends on the UUIDs of all + dependent passes and the pass config. See InductorPass for more info. 
+ """ + pass_manager_uuid = "fshdakhsa" + state = {"pass_config": pass_manager_uuid, "passes": []} + for pass_ in self.passes: + state["passes"].append(pass_.uuid()) + state["passes"].append(self.fix_functionalization.uuid()) + return InductorPass.hash_dict(state) diff --git a/python/sglang/srt/model_executor/compilation/piecewise_context_manager.py b/python/sglang/srt/model_executor/compilation/piecewise_context_manager.py new file mode 100644 index 000000000..38d17a6df --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/piecewise_context_manager.py @@ -0,0 +1,40 @@ +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, List, Optional + +from sglang.srt.model_executor.forward_batch_info import ForwardBatch + + +@dataclass +class ForwardContext: + def __init__(self): + self.forward_batch = None + self.attention_layer = None + + def set_forward_batch(self, forward_batch: ForwardBatch): + self.forward_batch = forward_batch + + def set_attention_layers(self, layers: List[Any]): + self.attention_layers = layers + + +_forward_context: Optional[ForwardContext] = None + + +def get_forward_context() -> Optional[ForwardContext]: + if _forward_context is None: + return None + return _forward_context + + +@contextmanager +def set_forward_context(forward_batch: ForwardBatch, attention_layers: List[Any]): + global _forward_context + prev_forward_context = _forward_context + _forward_context = ForwardContext() + _forward_context.set_forward_batch(forward_batch) + _forward_context.set_attention_layers(attention_layers) + try: + yield + finally: + _forward_context = prev_forward_context diff --git a/python/sglang/srt/model_executor/compilation/weak_ref_tensor.cpp b/python/sglang/srt/model_executor/compilation/weak_ref_tensor.cpp new file mode 100644 index 000000000..bf49367c8 --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/weak_ref_tensor.cpp @@ -0,0 +1,28 @@ +// Adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/ops.h + +#include +#include + +static at::Tensor weak_ref_tensor(at::Tensor &tensor) { + TORCH_CHECK(tensor.is_cuda(), "weak_ref_tensor expects a CUDA tensor"); + + void *data_ptr = tensor.data_ptr(); + std::vector sizes = tensor.sizes().vec(); + std::vector strides = tensor.strides().vec(); + + auto options = tensor.options(); + + auto new_tensor = torch::from_blob(data_ptr, sizes, strides, options); + + return new_tensor; +} + +TORCH_LIBRARY(jit_weak_ref_tensor, ops) { + ops.def("weak_ref_tensor(Tensor input) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(jit_weak_ref_tensor, CUDA, ops) { + ops.impl("weak_ref_tensor", weak_ref_tensor); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/python/sglang/srt/model_executor/compilation/weak_ref_tensor_jit.py b/python/sglang/srt/model_executor/compilation/weak_ref_tensor_jit.py new file mode 100644 index 000000000..094393fb2 --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/weak_ref_tensor_jit.py @@ -0,0 +1,16 @@ +import os + +import torch +from torch.utils.cpp_extension import load + +_abs_path = os.path.dirname(os.path.abspath(__file__)) + +load( + name="weak_ref_tensor_ext", + sources=[f"{_abs_path}/weak_ref_tensor.cpp"], + extra_cflags=["-O3"], +) + +x = torch.arange(12, device="cuda").reshape(3, 4) +y = torch.ops.jit_weak_ref_tensor.weak_ref_tensor(x) +print("alias:", x.data_ptr() == y.data_ptr()) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 5b1b9d22a..fea4a49ef 100644 
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -108,8 +108,15 @@ from sglang.srt.mem_cache.memory_pool import (
 )
 from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
+from sglang.srt.model_executor.piecewise_cuda_graph_runner import (
+    PiecewiseCudaGraphRunner,
+)
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
 from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
@@ -307,6 +314,26 @@ class ModelRunner:
         self._model_update_group = {}
         self._weights_send_group = {}
 
+        if (
+            self.server_args.enable_piecewise_cuda_graph
+            and self.can_run_piecewise_cuda_graph()
+        ):
+            self.attention_layers = []
+            for layer in self.model.model.layers:
+                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "attn"):
+                    self.attention_layers.append(layer.self_attn.attn)
+            if len(self.attention_layers) < self.model_config.num_hidden_layers:
+                # TODO(yuwei): support Non-Standard GQA
+                log_info_on_rank0(
+                    logger,
+                    "Disable piecewise CUDA graph because some layers do not use standard GQA",
+                )
+                self.piecewise_cuda_graph_runner = None
+            else:
+                self.piecewise_cuda_graph_runner = PiecewiseCudaGraphRunner(self)
+        else:
+            self.piecewise_cuda_graph_runner = None
+
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
 
@@ -692,6 +719,7 @@ class ModelRunner:
             pipeline_model_parallel_size=self.pp_size,
             expert_model_parallel_size=self.moe_ep_size,
             duplicate_tp_group=self.server_args.enable_pdmux,
+            torch_compile=self.server_args.enable_piecewise_cuda_graph,
         )
         initialize_dp_attention(
             server_args=self.server_args,
@@ -1411,6 +1439,27 @@ class ModelRunner:
             f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
         )
 
+    def can_run_piecewise_cuda_graph(self):
+        if self.server_args.disable_cuda_graph:
+            log_info_on_rank0(
+                logger, "Disable piecewise CUDA graph because disable_cuda_graph is set"
+            )
+            return False
+        if self.server_args.enable_torch_compile:
+            log_info_on_rank0(
+                logger,
+                "Disable piecewise CUDA graph because it conflicts with torch.compile",
+            )
+            return False
+        if self.pp_size > 1:
+            # TODO(yuwei): support PP
+            log_info_on_rank0(
+                logger,
+                "Disable piecewise CUDA graph because pipeline parallelism is not supported yet",
+            )
+            return False
+        return True
+
     def init_memory_pool(
         self,
         total_gpu_memory: int,
@@ -1932,6 +1981,11 @@ class ModelRunner:
             kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
         if not self.is_generation:
             kwargs["get_embedding"] = True
+
+        if self.piecewise_cuda_graph_runner is not None:
+            if self.piecewise_cuda_graph_runner.can_run(forward_batch):
+                return self.piecewise_cuda_graph_runner.replay(forward_batch, **kwargs)
+
         return self.model.forward(
             forward_batch.input_ids,
             forward_batch.positions,
diff --git a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
new file mode 100644
index 000000000..a5f3b1d54
--- /dev/null
+++ b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
@@ -0,0 +1,532 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== +"""Run the model with cuda graph and torch.compile.""" + +from __future__ import annotations + +import bisect +import gc +import logging +from contextlib import contextmanager +from typing import TYPE_CHECKING, Union + +import torch +import tqdm + +from sglang.srt.custom_op import CustomOp +from sglang.srt.distributed import get_tensor_model_parallel_rank +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + set_graph_pool_id, +) +from sglang.srt.distributed.parallel_state import graph_capture +from sglang.srt.layers.dp_attention import ( + DpPaddingMode, + get_attention_tp_rank, + get_attention_tp_size, + set_dp_buffer_len, +) +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.torchao_utils import save_gemlite_cache +from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig +from sglang.srt.model_executor.compilation.compile import ( + install_torch_compiled, + set_compiled, +) +from sglang.srt.model_executor.compilation.piecewise_context_manager import ( + set_forward_context, +) +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, + PPProxyTensors, +) +from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin +from sglang.srt.utils import get_available_gpu_memory, log_info_on_rank0 + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from sglang.srt.model_executor.model_runner import ModelRunner + +# Detect whether the current forward pass is in capture mode +is_capture_mode = False + + +def get_is_capture_mode(): + return is_capture_mode + + +@contextmanager +def model_capture_mode(): + global is_capture_mode + is_capture_mode = True + + yield + + is_capture_mode = False + + +@contextmanager +def freeze_gc(enable_cudagraph_gc: bool): + """ + Optimize garbage collection during CUDA graph capture. + Clean up, then freeze all remaining objects from being included + in future collections if GC is disabled during capture. + """ + gc.collect() + should_freeze = not enable_cudagraph_gc + if should_freeze: + gc.freeze() + try: + yield + finally: + if should_freeze: + gc.unfreeze() + + +def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int): + for sub in model._modules.values(): + if isinstance(sub, CustomOp): + if reverse: + sub.leave_torch_compile() + else: + sub.enter_torch_compile(num_tokens=num_tokens) + if isinstance(sub, torch.nn.Module): + _to_torch(sub, reverse, num_tokens) + + +@contextmanager +def patch_model(model: torch.nn.Module): + try: + _to_torch(model, reverse=False, num_tokens=16) + yield model + finally: + _to_torch(model, reverse=True, num_tokens=16) + + +# Reuse this memory pool across all cuda graph runners. 
+global_graph_memory_pool = None + + +def get_global_graph_memory_pool(): + return global_graph_memory_pool + + +def set_global_graph_memory_pool(val): + global global_graph_memory_pool + global_graph_memory_pool = val + + +class PiecewiseCudaGraphRunner: + """A PiecewiseCudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile.""" + + def __init__(self, model_runner: ModelRunner): + # Parse args + self.model_runner = model_runner + self.device = model_runner.device + self.device_module = torch.get_device_module(self.device) + self.graphs = {} + self.output_buffers = {} + self.tp_size = model_runner.server_args.tp_size + self.dp_size = model_runner.server_args.dp_size + self.pp_size = model_runner.server_args.pp_size + + self.attn_tp_size = get_attention_tp_size() + self.attn_tp_rank = get_attention_tp_rank() + + assert ( + self.model_runner.server_args.piecewise_cuda_graph_tokens is not None + ), "piecewise_cuda_graph_tokens is not set" + self.compile_config = CompilationConfig( + self.model_runner.server_args.piecewise_cuda_graph_tokens + ) + + # Batch sizes to capture + self.capture_num_tokens = self.compile_config.get_capture_sizes() + log_info_on_rank0( + logger, f"Capture cuda graph num tokens {self.capture_num_tokens}" + ) + self.capture_forward_mode = ForwardMode.EXTEND + self.capture_hidden_mode = CaptureHiddenMode.NULL + + # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup + if model_runner.server_args.enable_return_hidden_states: + self.capture_hidden_mode = CaptureHiddenMode.FULL + + # Attention backend + self.max_num_tokens = max(self.capture_num_tokens) + + # Graph inputs + with torch.device(self.device): + self.input_ids = torch.zeros((self.max_num_tokens,), dtype=torch.int64) + self.out_cache_loc = torch.zeros( + (self.max_num_tokens,), dtype=self._cache_loc_dtype() + ) + self.positions = torch.zeros((self.max_num_tokens,), dtype=torch.int64) + self.tbo_plugin = TboCudaGraphRunnerPlugin() + + self.attention_layers = self.model_runner.attention_layers + + if get_global_graph_memory_pool() is None: + set_global_graph_memory_pool(self.device_module.graph_pool_handle()) + # Set graph pool id globally to be able to use symmetric memory + set_graph_pool_id(get_global_graph_memory_pool()) + + with patch_model(self.model_runner.model.model) as patched_model: + install_torch_compiled( + patched_model, + fullgraph=True, + dynamic_arg_dims=None, + compile_config=self.compile_config, + graph_pool=get_global_graph_memory_pool(), + ) + + with set_compiled(True): + self.warmup_and_capture() + + # Capture + try: + with model_capture_mode(): + self.capture() + except RuntimeError as e: + raise Exception( + f"Capture cuda graph failed: {e}\n{PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG}" + ) + + self.raw_num_tokens = 0 + + def warmup_and_capture(self): + num_tokens = 2 + with torch.device(self.device): + forward_batch = ForwardBatch( + forward_mode=ForwardMode.EXTEND, + batch_size=1, + input_ids=torch.randint(0, 100, (num_tokens,), device=self.device), + req_pool_indices=torch.arange(1, device=self.device), + seq_lens=torch.tensor([num_tokens], device=self.device), + next_token_logits_buffer=None, + orig_seq_lens=torch.tensor([num_tokens], device=self.device), + seq_lens_cpu=torch.tensor([num_tokens]), + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + attn_backend=self.model_runner.attn_backend, + out_cache_loc=torch.randint(0, 100, 
(num_tokens,), device=self.device), + seq_lens_sum=num_tokens, + encoder_lens=None, + return_logprob=False, + extend_seq_lens=torch.tensor([num_tokens], device=self.device), + extend_prefix_lens=torch.tensor([num_tokens], device=self.device), + extend_start_loc=torch.tensor([0], device=self.device), + extend_prefix_lens_cpu=torch.tensor([num_tokens]), + extend_seq_lens_cpu=torch.tensor([num_tokens]), + extend_logprob_start_lens_cpu=torch.tensor([num_tokens]), + positions=torch.arange(num_tokens, device=self.device), + global_num_tokens_gpu=None, + global_num_tokens_for_logprob_gpu=None, + dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(), + global_dp_buffer_len=None, + mrope_positions=None, + spec_algorithm=None, + spec_info=None, + capture_hidden_mode=CaptureHiddenMode.NULL, + num_token_non_padded=None, + global_forward_mode=ForwardMode.EXTEND, + lora_ids=None, + ) + + with set_forward_context(forward_batch, self.attention_layers): + _ = self.model_runner.model.forward( + forward_batch.input_ids, + forward_batch.positions, + forward_batch, + ) + + def _cache_loc_dtype(self): + return torch.int64 + + def can_run(self, forward_batch: ForwardBatch): + num_tokens = len(forward_batch.input_ids) + # TODO(yuwei): support return logprob + if forward_batch.return_logprob: + return False + if num_tokens <= self.max_num_tokens: + return True + return False + + def capture(self) -> None: + # Trigger CUDA graph capture for specific shapes. + # Capture the large shapes first so that the smaller shapes + # can reuse the memory pool allocated for the large shapes. + with freeze_gc( + self.model_runner.server_args.enable_cudagraph_gc + ), graph_capture() as graph_capture_context: + self.stream = graph_capture_context.stream + avail_mem = get_available_gpu_memory( + self.model_runner.device, + self.model_runner.gpu_id, + empty_cache=False, + ) + # Reverse the order to enable better memory sharing across cuda graphs. + capture_range = ( + tqdm.tqdm(list(reversed(self.capture_num_tokens))) + if get_tensor_model_parallel_rank() == 0 + else reversed(self.capture_num_tokens) + ) + for i, num_tokens in enumerate(capture_range): + if get_tensor_model_parallel_rank() == 0: + avail_mem = get_available_gpu_memory( + self.model_runner.device, + self.model_runner.gpu_id, + empty_cache=False, + ) + capture_range.set_description( + f"Capturing num tokens ({num_tokens=} {avail_mem=:.2f} GB)" + ) + + with set_compiled(True): + self.capture_one_batch_size(num_tokens) + + # Save gemlite cache after each capture + save_gemlite_cache() + + def capture_one_batch_size(self, num_tokens: int): + stream = self.stream + bs = 1 + + # Graph inputs + input_ids = self.input_ids[:num_tokens] + out_cache_loc = self.out_cache_loc[:num_tokens] + positions = self.positions[:num_tokens] + + # pipeline parallelism + if self.pp_size > 1: + pp_proxy_tensors = PPProxyTensors( + {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()} + ) + + global_dp_buffer_len = None + + if self.model_runner.server_args.enable_lora: + # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever + # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization). 
+ lora_ids = [None] * bs + else: + lora_ids = None + + with torch.device(self.device): + forward_batch = ForwardBatch( + forward_mode=ForwardMode.EXTEND, + batch_size=bs, + input_ids=input_ids, + req_pool_indices=torch.arange(bs, device=self.device), + seq_lens=torch.tensor([num_tokens], device=self.device), + next_token_logits_buffer=None, + orig_seq_lens=torch.tensor([num_tokens], device=self.device), + seq_lens_cpu=torch.tensor([num_tokens]), + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + attn_backend=self.model_runner.attn_backend, + out_cache_loc=out_cache_loc, + seq_lens_sum=num_tokens, + encoder_lens=None, + return_logprob=False, + extend_seq_lens=torch.tensor([num_tokens], device=self.device), + extend_prefix_lens=torch.tensor([num_tokens], device=self.device), + extend_start_loc=torch.tensor([0], device=self.device), + extend_prefix_lens_cpu=torch.tensor([num_tokens]), + extend_seq_lens_cpu=torch.tensor([num_tokens]), + extend_logprob_start_lens_cpu=torch.tensor([num_tokens]), + positions=positions, + global_num_tokens_gpu=None, + global_num_tokens_for_logprob_gpu=None, + dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(), + global_dp_buffer_len=None, + mrope_positions=None, + spec_algorithm=None, + spec_info=None, + capture_hidden_mode=CaptureHiddenMode.NULL, + num_token_non_padded=None, + global_forward_mode=ForwardMode.EXTEND, + lora_ids=None, + ) + self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens) + + if lora_ids is not None: + self.model_runner.lora_manager.prepare_lora_batch(forward_batch) + + # # Attention backend + self.model_runner.attn_backend.init_forward_metadata(forward_batch) + + # Run and capture + def run_once(): + # Clean intermediate result cache for DP attention + forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None + set_dp_buffer_len(global_dp_buffer_len, num_tokens) + + kwargs = {} + with set_forward_context(forward_batch, self.attention_layers): + self.model_runner.model.forward( + forward_batch.input_ids, + forward_batch.positions, + forward_batch, + **kwargs, + ) + return + + for _ in range(2): + self.device_module.synchronize() + self.model_runner.tp_group.barrier() + run_once() + + return + + def replay_prepare( + self, + forward_batch: ForwardBatch, + **kwargs, + ): + num_tokens = len(forward_batch.input_ids) + index = bisect.bisect_left(self.capture_num_tokens, num_tokens) + static_num_tokens = self.capture_num_tokens[index] + self.raw_num_tokens = num_tokens + if static_num_tokens != num_tokens: + self.out_cache_loc.zero_() + bs = forward_batch.batch_size + + self.input_ids[:num_tokens].copy_(forward_batch.input_ids) + self.positions[:num_tokens].copy_(forward_batch.positions) + self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc) + + input_ids = self.input_ids[:static_num_tokens] + positions = self.positions[:static_num_tokens] + out_cache_loc = self.out_cache_loc[:static_num_tokens] + + next_token_logits_buffer = None + mrope_positions = None + + static_forward_batch = ForwardBatch( + forward_mode=forward_batch.forward_mode, + batch_size=bs, + input_ids=input_ids, + req_pool_indices=forward_batch.req_pool_indices, + seq_lens=forward_batch.seq_lens, + next_token_logits_buffer=next_token_logits_buffer, + orig_seq_lens=forward_batch.orig_seq_lens, + seq_lens_cpu=forward_batch.seq_lens_cpu, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + 
attn_backend=self.model_runner.attn_backend, + out_cache_loc=out_cache_loc, + seq_lens_sum=forward_batch.seq_lens_sum, + encoder_lens=forward_batch.encoder_lens, + return_logprob=forward_batch.return_logprob, + extend_seq_lens=forward_batch.extend_seq_lens, + extend_prefix_lens=forward_batch.extend_prefix_lens, + extend_start_loc=forward_batch.extend_start_loc, + extend_prefix_lens_cpu=forward_batch.extend_prefix_lens_cpu, + extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu, + extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu, + extend_num_tokens=forward_batch.extend_num_tokens, + extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu, + positions=positions, + global_num_tokens_gpu=forward_batch.global_num_tokens_gpu, + global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu, + dp_padding_mode=forward_batch.dp_padding_mode, + global_dp_buffer_len=forward_batch.global_dp_buffer_len, + mrope_positions=mrope_positions, + spec_algorithm=forward_batch.spec_algorithm, + spec_info=forward_batch.spec_info, + capture_hidden_mode=forward_batch.capture_hidden_mode, + num_token_non_padded=forward_batch.num_token_non_padded, + global_forward_mode=forward_batch.global_forward_mode, + lora_ids=forward_batch.lora_ids, + sampling_info=forward_batch.sampling_info, + mm_inputs=forward_batch.mm_inputs, + temp_scaled_logprobs=forward_batch.temp_scaled_logprobs, + temperature=forward_batch.temperature, + top_p_normalized_logprobs=forward_batch.top_p_normalized_logprobs, + top_p=forward_batch.top_p, + ) + + return static_forward_batch + + def replay( + self, + forward_batch: ForwardBatch, + **kwargs, + ) -> Union[LogitsProcessorOutput, PPProxyTensors]: + static_forward_batch = self.replay_prepare(forward_batch, **kwargs) + # Replay + with set_forward_context(static_forward_batch, self.attention_layers): + with set_compiled(True): + output = self.model_runner.model.forward( + static_forward_batch.input_ids, + static_forward_batch.positions, + static_forward_batch, + **kwargs, + ) + if isinstance(output, LogitsProcessorOutput): + return LogitsProcessorOutput( + next_token_logits=output.next_token_logits[: self.raw_num_tokens], + hidden_states=( + output.hidden_states[: self.raw_num_tokens] + if output.hidden_states is not None + else None + ), + ) + else: + assert isinstance(output, PPProxyTensors) + # TODO(Yuwei): support PP Support + raise NotImplementedError( + "PPProxyTensors is not supported in PiecewiseCudaGraphRunner yet." + ) + + def get_spec_info(self, num_tokens: int): + spec_info = None + if ( + self.model_runner.spec_algorithm.is_eagle() + or self.model_runner.spec_algorithm.is_standalone() + ): + from sglang.srt.speculative.eagle_utils import EagleVerifyInput + + if self.model_runner.is_draft_worker: + raise RuntimeError("This should not happen.") + else: + spec_info = EagleVerifyInput( + draft_token=None, + custom_mask=self.custom_mask, + positions=None, + retrive_index=None, + retrive_next_token=None, + retrive_next_sibling=None, + retrive_cum_len=None, + spec_steps=self.model_runner.server_args.speculative_num_steps, + topk=self.model_runner.server_args.speculative_eagle_topk, + draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, + capture_hidden_mode=CaptureHiddenMode.FULL, + seq_lens_sum=None, + seq_lens_cpu=None, + ) + + return spec_info + + +PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG = ( + "Possible solutions:\n" + "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n" + "2. 
set --piecewise-cuda-graph-max-tokens to a smaller value (e.g., 512)\n" + "3. disable Piecewise CUDA graph by unset --enable-piecewise-cuda-graph\n" + "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" +) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index c08046130..29d5cc03e 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -417,7 +417,10 @@ class ServerArgs: enable_single_batch_overlap: bool = False tbo_token_distribution_threshold: float = 0.48 enable_torch_compile: bool = False + enable_piecewise_cuda_graph: bool = False torch_compile_max_bs: int = 32 + piecewise_cuda_graph_max_tokens: int = 4096 + piecewise_cuda_graph_tokens: Optional[List[int]] = None torchao_config: str = "" enable_nan_detection: bool = False enable_p2p_check: bool = False @@ -675,6 +678,11 @@ class ServerArgs: else: self.cuda_graph_max_bs = max(self.cuda_graph_bs) + if self.piecewise_cuda_graph_tokens is None: + self.piecewise_cuda_graph_tokens = ( + self._generate_piecewise_cuda_graph_tokens() + ) + if self.mem_fraction_static is None: # Constant meta data (e.g., from attention backend) reserved_mem = 512 @@ -753,6 +761,25 @@ class ServerArgs: return capture_bs + def _generate_piecewise_cuda_graph_tokens(self): + """ + Generate the list of batch sizes for piecewise CUDA graph capture + based on piecewise_cuda_graph_max_tokens. + """ + capture_sizes = ( + list(range(4, 33, 4)) + + list(range(48, 257, 16)) + + list(range(288, 513, 32)) + + list(range(640, 4096 + 1, 128)) + + list(range(4352, self.piecewise_cuda_graph_max_tokens + 1, 256)) + ) + + capture_sizes = [ + s for s in capture_sizes if s <= self.piecewise_cuda_graph_max_tokens + ] + + return capture_sizes + def _handle_hpu_backends(self): if self.device == "hpu": self.attention_backend = "torch_native" @@ -2649,12 +2676,29 @@ class ServerArgs: action="store_true", help="Optimize the model with torch.compile. Experimental feature.", ) + parser.add_argument( + "--enable-piecewise-cuda-graph", + action="store_true", + help="Optimize the model with piecewise cuda graph for extend/prefill only. 
Experimental feature.", + ) + parser.add_argument( + "--piecewise-cuda-graph-tokens", + type=json_list_type, + default=ServerArgs.piecewise_cuda_graph_tokens, + help="Set the list of tokens when using piecewise cuda graph.", + ) parser.add_argument( "--torch-compile-max-bs", type=int, default=ServerArgs.torch_compile_max_bs, help="Set the maximum batch size when using torch compile.", ) + parser.add_argument( + "--piecewise-cuda-graph-max-tokens", + type=int, + default=ServerArgs.piecewise_cuda_graph_max_tokens, + help="Set the maximum tokens when using piecewise cuda graph.", + ) parser.add_argument( "--torchao-config", type=str, diff --git a/test/srt/test_piecewise_cuda_graph.py b/test/srt/test_piecewise_cuda_graph.py new file mode 100644 index 000000000..ed41e1e04 --- /dev/null +++ b/test/srt/test_piecewise_cuda_graph.py @@ -0,0 +1,59 @@ +import time +import unittest + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + SimpleNamespace, + popen_launch_server, + run_bench_one_batch, +) + + +class TestPiecewiseCudaGraphCorrectness(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--enable-piecewise-cuda-graph"], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gpqa(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gpqa", + num_examples=64, + num_threads=16, + ) + + metrics = run_eval(args) + self.assertGreaterEqual(metrics["score"], 0.235) + + +class TestPiecewiseCudaGraphBenchmark(CustomTestCase): + + def test_latency(self): + prefill_latency, _, _ = run_bench_one_batch( + DEFAULT_MODEL_NAME_FOR_TEST, + other_args=["--enable-piecewise-cuda-graph"], + ) + self.assertLess(prefill_latency, 0.015) + + +if __name__ == "__main__": + unittest.main()
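
Note (not part of the patch): the capture schedule added in server_args.py and the padding logic in PiecewiseCudaGraphRunner.replay_prepare work together: an extend/prefill batch of N tokens is padded up to the smallest captured token count >= N, so replay cost is bounded by the granularity of the schedule. Below is a minimal standalone sketch of that interaction. The helper names generate_capture_sizes and select_capture_size are illustrative only; the actual logic lives in ServerArgs._generate_piecewise_cuda_graph_tokens and PiecewiseCudaGraphRunner.replay_prepare, and the sketch assumes the capture sizes are the ascending list produced there.

import bisect

def generate_capture_sizes(max_tokens: int) -> list[int]:
    # Mirrors _generate_piecewise_cuda_graph_tokens: fine granularity for
    # small extends, coarser steps for long prefills, capped at max_tokens.
    sizes = (
        list(range(4, 33, 4))
        + list(range(48, 257, 16))
        + list(range(288, 513, 32))
        + list(range(640, 4096 + 1, 128))
        + list(range(4352, max_tokens + 1, 256))
    )
    return [s for s in sizes if s <= max_tokens]

def select_capture_size(capture_sizes: list[int], num_tokens: int) -> int:
    # Mirrors replay_prepare: pick the smallest captured size that fits the
    # batch; the static input buffers are then sliced/padded to that size.
    index = bisect.bisect_left(capture_sizes, num_tokens)
    return capture_sizes[index]

sizes = generate_capture_sizes(4096)
assert select_capture_size(sizes, 1000) == 1024  # a 1000-token prefill replays the 1024-token graph
assert select_capture_size(sizes, 4096) == 4096  # the configured maximum is itself captured

Batches larger than the maximum captured size are rejected by can_run and fall back to the regular eager forward path, so the schedule only needs to cover sizes up to --piecewise-cuda-graph-max-tokens.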