Piecewise CUDA Graph Support & Torch Compile Backend (#10062)
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
@@ -185,7 +185,7 @@ class CustomAllreduce:
|
||||
# is enough for 131072 such tuples. The largest model I've seen only
|
||||
# needs fewer than 10000 registered tuples.
|
||||
self.rank_data = torch.empty(
|
||||
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
|
||||
max_size, dtype=torch.uint8, device=self.device
|
||||
)
|
||||
self._ptr = ops.init_custom_ar(
|
||||
self.meta_ptrs, self.rank_data, rank, self.full_nvlink
|
||||
@@ -202,7 +202,7 @@ class CustomAllreduce:
|
||||
)
|
||||
handles, offsets = self._gather_ipc_meta(shard_data)
|
||||
self.rank_data = torch.empty(
|
||||
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
|
||||
max_size, dtype=torch.uint8, device=self.device
|
||||
)
|
||||
self._ptr = ops.init_custom_ar(
|
||||
self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
|
||||
|
||||
@@ -239,6 +239,7 @@ class GroupCoordinator:
|
||||
use_npu_communicator: bool,
|
||||
use_message_queue_broadcaster: bool = False,
|
||||
group_name: Optional[str] = None,
|
||||
torch_compile: Optional[bool] = None,
|
||||
):
|
||||
# Set group info
|
||||
group_name = group_name or "anonymous"
|
||||
@@ -326,10 +327,18 @@ class GroupCoordinator:
|
||||
self.qr_comm: Optional[QuickAllReduce] = None
|
||||
if use_custom_allreduce and self.world_size > 1:
|
||||
# Initialize a custom fast all-reduce implementation.
|
||||
if torch_compile is not None and torch_compile:
|
||||
# For piecewise CUDA graphs, the custom allreduce buffer must be larger
# to avoid illegal CUDA memory accesses.
|
||||
ca_max_size = 256 * 1024 * 1024
|
||||
else:
|
||||
ca_max_size = 8 * 1024 * 1024
|
||||
try:
|
||||
# print(f"ca_max_size: {ca_max_size}")
|
||||
self.ca_comm = CustomAllreduce(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
max_size=ca_max_size,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
@@ -1260,6 +1269,7 @@ def init_model_parallel_group(
|
||||
group_name: Optional[str] = None,
|
||||
use_mscclpp_allreduce: Optional[bool] = None,
|
||||
use_symm_mem_allreduce: Optional[bool] = None,
|
||||
torch_compile: Optional[bool] = None,
|
||||
) -> GroupCoordinator:
|
||||
if use_custom_allreduce is None:
|
||||
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
|
||||
@@ -1280,6 +1290,7 @@ def init_model_parallel_group(
|
||||
use_npu_communicator=True,
|
||||
use_message_queue_broadcaster=use_message_queue_broadcaster,
|
||||
group_name=group_name,
|
||||
torch_compile=torch_compile,
|
||||
)
|
||||
|
||||
|
||||
@@ -1439,6 +1450,7 @@ def initialize_model_parallel(
|
||||
pipeline_model_parallel_size: int = 1,
|
||||
backend: Optional[str] = None,
|
||||
duplicate_tp_group: bool = False,
|
||||
torch_compile: Optional[bool] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize model parallel groups.
|
||||
@@ -1494,6 +1506,7 @@ def initialize_model_parallel(
|
||||
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
|
||||
),
|
||||
group_name="tp",
|
||||
torch_compile=torch_compile,
|
||||
)
|
||||
|
||||
if duplicate_tp_group:
|
||||
@@ -1509,6 +1522,7 @@ def initialize_model_parallel(
|
||||
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
|
||||
),
|
||||
group_name="pdmux_prefill_tp",
|
||||
torch_compile=torch_compile,
|
||||
)
|
||||
_TP.pynccl_comm.disabled = False
|
||||
_PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
|
||||
@@ -1518,7 +1532,6 @@ def initialize_model_parallel(
|
||||
|
||||
global _MOE_EP
|
||||
assert _MOE_EP is None, "expert model parallel group is already initialized"
|
||||
|
||||
if moe_ep_size == tensor_model_parallel_size:
|
||||
_MOE_EP = _TP
|
||||
else:
|
||||
@@ -1539,7 +1552,6 @@ def initialize_model_parallel(
|
||||
|
||||
global _MOE_TP
|
||||
assert _MOE_TP is None, "expert model parallel group is already initialized"
|
||||
|
||||
if moe_tp_size == tensor_model_parallel_size:
|
||||
_MOE_TP = _TP
|
||||
else:
|
||||
|
||||
@@ -43,11 +43,16 @@ _is_cpu = is_cpu()
|
||||
_is_xpu = is_xpu()
|
||||
|
||||
if _is_cuda:
|
||||
if _is_flashinfer_available:
|
||||
from flashinfer.norm import fused_add_rmsnorm
|
||||
else:
|
||||
from sgl_kernel import fused_add_rmsnorm
|
||||
from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
|
||||
# if _is_flashinfer_available:
|
||||
# from flashinfer.norm import fused_add_rmsnorm
|
||||
# else:
|
||||
from sgl_kernel import (
|
||||
fused_add_rmsnorm,
|
||||
gemma_fused_add_rmsnorm,
|
||||
gemma_rmsnorm,
|
||||
rmsnorm,
|
||||
)
|
||||
|
||||
|
||||
if _use_aiter:
|
||||
from aiter import rmsnorm2d_fwd as rms_norm
|
||||
|
||||
@@ -17,12 +17,18 @@ from __future__ import annotations
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
from sglang.srt.model_executor.compilation.piecewise_context_manager import (
|
||||
get_forward_context,
|
||||
)
|
||||
from sglang.srt.utils import direct_register_custom_op
|
||||
|
||||
|
||||
class AttentionType(Enum):
|
||||
"""
|
||||
@@ -105,12 +111,58 @@ class RadixAttention(nn.Module):
|
||||
else:
|
||||
k = k.view(-1, self.tp_k_head_num, self.v_head_dim)
|
||||
|
||||
return forward_batch.attn_backend.forward(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
self,
|
||||
forward_batch,
|
||||
save_kv_cache,
|
||||
**kwargs,
|
||||
)
|
||||
if forward_batch.forward_mode.is_extend() and get_forward_context() is not None:
|
||||
output = torch.zeros_like(q)
|
||||
torch.ops.sglang.unified_attention_with_output(
|
||||
q, k, v, output, save_kv_cache, self.layer_id
|
||||
)
|
||||
return output
|
||||
else:
|
||||
return forward_batch.attn_backend.forward(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
self,
|
||||
forward_batch,
|
||||
save_kv_cache,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def unified_attention_with_output(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
save_kv_cache: bool,
|
||||
layer_id: int,
|
||||
) -> None:
|
||||
context = get_forward_context()
|
||||
forward_batch = context.forward_batch
|
||||
attention_layers = context.attention_layers
|
||||
attention_layer = attention_layers[layer_id]
|
||||
ret = forward_batch.attn_backend.forward(
|
||||
query, key, value, attention_layer, forward_batch, save_kv_cache
|
||||
)
|
||||
assert output.shape == ret.shape
|
||||
output.copy_(ret)
|
||||
return
|
||||
|
||||
|
||||
def unified_attention_with_output_fake(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
save_kv_cache: bool,
|
||||
layer_id: int,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="unified_attention_with_output",
|
||||
op_func=unified_attention_with_output,
|
||||
mutates_args=["output"],
|
||||
fake_impl=unified_attention_with_output_fake,
|
||||
)
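Once registered, the op is callable through torch.ops and appears to Dynamo as a single opaque node; that node is what the piecewise backend later splits on. A minimal sketch of the call site (it mirrors the extend path in RadixAttention.forward above):

# Sketch of the call site inside RadixAttention.forward (extend path):
output = torch.zeros_like(q)
torch.ops.sglang.unified_attention_with_output(
    q, k, v, output, save_kv_cache, self.layer_id
)
return output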
|
||||
|
||||
python/sglang/srt/model_executor/compilation/backend.py (new file, 435 lines)
@@ -0,0 +1,435 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/backend.py
|
||||
|
||||
|
||||
import ast
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import pprint
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
|
||||
from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig
|
||||
from sglang.srt.model_executor.compilation.compilation_counter import (
|
||||
compilation_counter,
|
||||
)
|
||||
from sglang.srt.model_executor.compilation.compiler_interface import InductorAdaptor
|
||||
from sglang.srt.model_executor.compilation.cuda_piecewise_backend import (
|
||||
CUDAPiecewiseBackend,
|
||||
)
|
||||
from sglang.srt.model_executor.compilation.pass_manager import PostGradPassManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_compiler():
|
||||
return InductorAdaptor()
|
||||
|
||||
|
||||
class CompilerManager:
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
self.cache = dict()
|
||||
self.is_cache_updated = False
|
||||
self.compiler = make_compiler()
|
||||
|
||||
def compute_hash(self):
|
||||
return self.compiler.compute_hash()
|
||||
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
):
|
||||
self.disable_cache = disable_cache
|
||||
self.cache_dir = cache_dir
|
||||
self.cache_file_path = os.path.join(cache_dir, "sglang_compile_cache.py")
|
||||
|
||||
if not disable_cache and os.path.exists(self.cache_file_path):
|
||||
with open(self.cache_file_path) as f:
|
||||
self.cache = ast.literal_eval(f.read())
|
||||
|
||||
self.compiler.initialize_cache(
|
||||
cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
|
||||
)
|
||||
|
||||
def save_to_file(self):
|
||||
if self.disable_cache or not self.is_cache_updated:
|
||||
return
|
||||
printer = pprint.PrettyPrinter(indent=4)
|
||||
data = printer.pformat(self.cache)
|
||||
with open(self.cache_file_path, "w") as f:
|
||||
f.write(data)
|
||||
|
||||
def load(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None,
|
||||
) -> Optional[Callable]:
|
||||
handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
|
||||
compiled_graph = self.compiler.load(
|
||||
handle, graph, example_inputs, graph_index, runtime_shape
|
||||
)
|
||||
if runtime_shape is None:
|
||||
logger.debug(
|
||||
"Directly load the %s-th graph for dynamic shape from %s via "
|
||||
"handle %s",
|
||||
graph_index,
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Directly load the %s-th graph for shape %s from %s via " "handle %s",
|
||||
graph_index,
|
||||
str(runtime_shape),
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
return compiled_graph
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs,
|
||||
inductor_config: dict[str, Any],
|
||||
graph_index: int = 0,
|
||||
num_graphs: int = 1,
|
||||
runtime_shape: Optional[int] = None,
|
||||
) -> Any:
|
||||
if graph_index == 0:
|
||||
# before compiling the first graph, record the start time
|
||||
global compilation_start_time
|
||||
compilation_start_time = time.time()
|
||||
|
||||
compilation_counter.num_backend_compilations += 1
|
||||
|
||||
compiled_graph = None
|
||||
|
||||
# TODO(Yuwei): support cache loading
|
||||
|
||||
# no compiler cached the graph, or the cache is disabled,
|
||||
# we need to compile it
|
||||
if isinstance(self.compiler, InductorAdaptor):
|
||||
maybe_key = None
|
||||
else:
|
||||
maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
|
||||
compiled_graph, handle = self.compiler.compile(
|
||||
graph, example_inputs, inductor_config, runtime_shape, maybe_key
|
||||
)
|
||||
|
||||
assert compiled_graph is not None, "Failed to compile the graph"
|
||||
|
||||
# store the artifact in the cache
|
||||
if handle is not None:
|
||||
self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
|
||||
compilation_counter.num_cache_entries_updated += 1
|
||||
self.is_cache_updated = True
|
||||
if graph_index == 0:
|
||||
# adds some info logging for the first graph
|
||||
if runtime_shape is None:
|
||||
logger.info("Cache the graph for dynamic shape for later use")
|
||||
else:
|
||||
logger.info(
|
||||
"Cache the graph of shape %s for later use", str(runtime_shape)
|
||||
)
|
||||
if runtime_shape is None:
|
||||
logger.debug(
|
||||
"Store the %s-th graph for dynamic shape from %s via " "handle %s",
|
||||
graph_index,
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Store the %s-th graph for shape %s from %s via handle %s",
|
||||
graph_index,
|
||||
str(runtime_shape),
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
|
||||
# after compiling the last graph, record the end time
|
||||
if graph_index == num_graphs - 1:
|
||||
now = time.time()
|
||||
elapsed = now - compilation_start_time
|
||||
if runtime_shape is None:
|
||||
logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed)
|
||||
else:
|
||||
logger.info(
|
||||
"Compiling a graph for shape %s takes %.2f s",
|
||||
runtime_shape,
|
||||
elapsed,
|
||||
)
|
||||
|
||||
return compiled_graph
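For reference, the cache file written by save_to_file is a plain Python literal that maps (runtime_shape, graph_index, compiler_name) to the handle returned by compile, and it is read back with ast.literal_eval. A hypothetical sglang_compile_cache.py (hashes and paths are made up):

# Hypothetical contents of sglang_compile_cache.py; each handle is
# (fx_graph_cache_hash, path_to_compiled_output_code).
{
    (None, 0, "inductor"): ("fabc0123de", "<cache_dir>/inductor_cache/xx/output_code.py"),
    (8, 0, "inductor"): ("f4567890ab", "<cache_dir>/inductor_cache/yy/output_code.py"),
}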
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SplitItem:
|
||||
submod_name: str
|
||||
graph_id: int
|
||||
is_splitting_graph: bool
|
||||
graph: fx.GraphModule
|
||||
|
||||
|
||||
def split_graph(
|
||||
graph: fx.GraphModule, ops: list[str]
|
||||
) -> tuple[fx.GraphModule, list[SplitItem]]:
|
||||
# split graph by ops
|
||||
subgraph_id = 0
|
||||
node_to_subgraph_id = {}
|
||||
split_op_graphs = []
|
||||
for node in graph.graph.nodes:
|
||||
if node.op in ("output", "placeholder"):
|
||||
continue
|
||||
if node.op == "call_function" and str(node.target) in ops:
|
||||
subgraph_id += 1
|
||||
node_to_subgraph_id[node] = subgraph_id
|
||||
split_op_graphs.append(subgraph_id)
|
||||
subgraph_id += 1
|
||||
else:
|
||||
node_to_subgraph_id[node] = subgraph_id
|
||||
|
||||
# `keep_original_order` is important!
|
||||
# otherwise pytorch might reorder the nodes and
|
||||
# the semantics of the graph will change when we
|
||||
# have mutations in the graph
|
||||
split_gm = torch.fx.passes.split_module.split_module(
|
||||
graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
|
||||
)
|
||||
|
||||
outputs = []
|
||||
|
||||
names = [name for (name, module) in split_gm.named_modules()]
|
||||
|
||||
for name in names:
|
||||
if "." in name or name == "":
|
||||
# recursive child module or the root module
|
||||
continue
|
||||
|
||||
module = getattr(split_gm, name)
|
||||
|
||||
graph_id = int(name.replace("submod_", ""))
|
||||
outputs.append(SplitItem(name, graph_id, (graph_id in split_op_graphs), module))
|
||||
|
||||
# sort by integer graph_id, rather than string name
|
||||
outputs.sort(key=lambda x: x.graph_id)
|
||||
|
||||
return split_gm, outputs
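A rough illustration of what the splitter returns (submodule names follow torch.fx.passes.split_module's convention; the attention submodules are flagged so they are excluded from compilation and CUDA graph capture):

# Illustration: splitting on the attention op yields alternating submodules,
# e.g. submod_0 (capturable), submod_1 (attention, is_splitting_graph=True),
# submod_2 (capturable), ...
split_gm, items = split_graph(gm, ["sglang.unified_attention_with_output"])
capturable = [item.submod_name for item in items if not item.is_splitting_graph]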
|
||||
|
||||
|
||||
# we share the global graph pool among all the backends
|
||||
global_graph_pool = None
|
||||
|
||||
compilation_start_time = 0.0
|
||||
|
||||
|
||||
class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
def __init__(
|
||||
self,
|
||||
module: torch.fx.GraphModule,
|
||||
compile_submod_names: list[str],
|
||||
inductor_config: dict[str, Any],
|
||||
graph_pool,
|
||||
compile_config: CompilationConfig,
|
||||
sglang_backend: "SGLangBackend",
|
||||
):
|
||||
super().__init__(module)
|
||||
from torch._guards import detect_fake_mode
|
||||
|
||||
self.fake_mode = detect_fake_mode()
|
||||
self.compile_submod_names = compile_submod_names
|
||||
self.graph_pool = graph_pool
|
||||
self.sglang_backend = sglang_backend
|
||||
# When True, it annoyingly dumps the torch.fx.Graph on errors.
|
||||
self.extra_traceback = False
|
||||
self.inductor_config = inductor_config
|
||||
self.compile_config = compile_config
|
||||
|
||||
def run(self, *args):
|
||||
fake_args = [
|
||||
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
|
||||
for t in args
|
||||
]
|
||||
with self.fake_mode, enable_python_dispatcher():
|
||||
return super().run(*fake_args)
|
||||
|
||||
def call_module(
|
||||
self,
|
||||
target: torch.fx.node.Target,
|
||||
args: tuple[torch.fx.node.Argument, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> Any:
|
||||
assert isinstance(target, str)
|
||||
output = super().call_module(target, args, kwargs)
|
||||
|
||||
if target in self.compile_submod_names:
|
||||
index = self.compile_submod_names.index(target)
|
||||
submod = self.fetch_attr(target)
|
||||
sym_shape_indices = [
|
||||
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
|
||||
]
|
||||
global compilation_start_time
|
||||
compiled_graph_for_dynamic_shape = (
|
||||
self.sglang_backend.compiler_manager.compile(
|
||||
submod,
|
||||
args,
|
||||
self.inductor_config,
|
||||
graph_index=index,
|
||||
num_graphs=len(self.compile_submod_names),
|
||||
runtime_shape=None,
|
||||
)
|
||||
)
|
||||
|
||||
self.module.__dict__[target] = CUDAPiecewiseBackend(
|
||||
submod,
|
||||
self.compile_config,
|
||||
self.inductor_config,
|
||||
self.graph_pool,
|
||||
index,
|
||||
len(self.compile_submod_names),
|
||||
sym_shape_indices,
|
||||
compiled_graph_for_dynamic_shape,
|
||||
self.sglang_backend,
|
||||
)
|
||||
|
||||
compilation_counter.num_piecewise_capturable_graphs_seen += 1
|
||||
|
||||
return output
|
||||
|
||||
|
||||
model_tag: str = "backbone"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_model_tag(tag: str):
|
||||
"""Context manager to set the model tag."""
|
||||
global model_tag
|
||||
assert (
|
||||
tag != model_tag
|
||||
), f"Model tag {tag} is the same as the current tag {model_tag}."
|
||||
old_tag = model_tag
|
||||
model_tag = tag
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
model_tag = old_tag
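A hypothetical use of the tag: compiling a speculative-decoding draft head separately from the backbone so each part gets its own cache sub-directory (install_torch_compiled is defined in compile.py below; draft_model, cfg, and pool are placeholder names):

# Hypothetical: compile a draft head under a separate cache prefix.
with set_model_tag("eagle_head"):
    install_torch_compiled(draft_model, compile_config=cfg, graph_pool=pool)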
|
||||
|
||||
|
||||
class SGLangBackend:
|
||||
|
||||
graph_pool: Any
|
||||
_called: bool = False
|
||||
# the graph we compiled
|
||||
graph: fx.GraphModule
|
||||
# the stitching graph module for all the piecewise graphs
|
||||
split_gm: fx.GraphModule
|
||||
piecewise_graphs: list[SplitItem]
|
||||
returned_callable: Callable
|
||||
# Inductor passes to run on the graph pre-defunctionalization
|
||||
post_grad_passes: Sequence[Callable]
|
||||
sym_tensor_indices: list[int]
|
||||
input_buffers: list[torch.Tensor]
|
||||
compiler_manager: CompilerManager
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CompilationConfig,
|
||||
graph_pool: Any,
|
||||
):
|
||||
assert graph_pool is not None
|
||||
self.graph_pool = graph_pool
|
||||
|
||||
self.post_grad_pass_manager = PostGradPassManager()
|
||||
self.sym_tensor_indices = []
|
||||
self.input_buffers = []
|
||||
|
||||
self.compiler_manager = CompilerManager()
|
||||
self.inductor_config = {
|
||||
"enable_auto_functionalized_v2": False,
|
||||
}
|
||||
self.compile_config = config
|
||||
|
||||
def configure_post_pass(self):
|
||||
self.post_grad_pass_manager.configure()
|
||||
self.inductor_config["post_grad_custom_post_pass"] = self.post_grad_pass_manager
|
||||
|
||||
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
|
||||
base_cache_dir = os.path.expanduser(
|
||||
os.getenv("SGLANG_CACHE_DIR", "~/.cache/sglang/")
|
||||
)
|
||||
|
||||
cache_hash = self.compiler_manager.compute_hash()
|
||||
cache_dir = os.path.join(
|
||||
base_cache_dir,
|
||||
"torch_compile_cache",
|
||||
cache_hash,
|
||||
)
|
||||
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
rank = 0
|
||||
dp_rank = 0
|
||||
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", model_tag)
|
||||
os.makedirs(local_cache_dir, exist_ok=True)
|
||||
self.compiler_manager.initialize_cache(
|
||||
local_cache_dir, disable_cache=False, prefix=""
|
||||
)
|
||||
compilation_counter.num_graphs_seen += 1
|
||||
|
||||
assert not self._called, "SGLangBackend can only be called once"
|
||||
|
||||
self.graph = graph
|
||||
self.configure_post_pass()
|
||||
|
||||
self.split_gm, self.piecewise_graphs = split_graph(
|
||||
graph, ["sglang.unified_attention_with_output"]
|
||||
)
|
||||
|
||||
from torch._dynamo.utils import lazy_format_graph_code
|
||||
|
||||
# depyf will hook lazy_format_graph_code and dump the graph
|
||||
# for debugging, no need to print the graph here
|
||||
lazy_format_graph_code("before split", self.graph)
|
||||
lazy_format_graph_code("after split", self.split_gm)
|
||||
|
||||
compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
|
||||
|
||||
submod_names_to_compile = [
|
||||
item.submod_name
|
||||
for item in self.piecewise_graphs
|
||||
if not item.is_splitting_graph
|
||||
]
|
||||
|
||||
PiecewiseCompileInterpreter(
|
||||
self.split_gm,
|
||||
submod_names_to_compile,
|
||||
self.inductor_config,
|
||||
self.graph_pool,
|
||||
self.compile_config,
|
||||
self,
|
||||
).run(*example_inputs)
|
||||
|
||||
graph_path = os.path.join(local_cache_dir, "computation_graph.py")
|
||||
if not os.path.exists(graph_path):
|
||||
# code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
|
||||
# use `print_readable` because it can include submodules
|
||||
src = (
|
||||
"from __future__ import annotations\nimport torch\n"
|
||||
+ self.split_gm.print_readable(print_output=False)
|
||||
)
|
||||
src = src.replace("<lambda>", "GraphModule")
|
||||
with open(graph_path, "w") as f:
|
||||
f.write(src)
|
||||
|
||||
logger.debug("Computation graph saved to %s", graph_path)
|
||||
|
||||
self._called = True
|
||||
return self.split_gm
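A minimal sketch of how the backend is meant to be wired into torch.compile (install_torch_compiled in compile.py below wraps exactly this pattern):

# Sketch: an SGLangBackend instance is callable, so it can be passed directly
# as the torch.compile backend; the returned callable runs split_gm.
backend = SGLangBackend(compile_config, graph_pool)
compiled_forward = torch.compile(model.forward, fullgraph=True, backend=backend)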
|
||||
@@ -0,0 +1,19 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_config.py
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
# TODO(Yuwei): improve compilation config support
|
||||
class CompilationConfig:
|
||||
def __init__(self, capture_sizes: List[int]):
|
||||
self.traced_files = set()
|
||||
self.capture_sizes = capture_sizes
|
||||
|
||||
def add_traced_file(self, file_path: str):
|
||||
self.traced_files.add(file_path)
|
||||
|
||||
def get_traced_files(self):
|
||||
return self.traced_files
|
||||
|
||||
def get_capture_sizes(self):
|
||||
return self.capture_sizes
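A hypothetical instantiation; capture_sizes lists the runtime shapes (token counts) for which the piecewise backend will capture CUDA graphs:

# Hypothetical: capture CUDA graphs for a few small decode batch sizes.
cfg = CompilationConfig(capture_sizes=[1, 2, 4, 8])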
|
||||
@@ -0,0 +1,47 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_counter.py
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CompilationCounter:
|
||||
num_models_seen: int = 0
|
||||
num_graphs_seen: int = 0
|
||||
# including the splitting ops
|
||||
num_piecewise_graphs_seen: int = 0
|
||||
# not including the splitting ops
|
||||
num_piecewise_capturable_graphs_seen: int = 0
|
||||
num_backend_compilations: int = 0
|
||||
# Number of gpu_model_runner attempts to trigger CUDAGraphs capture
|
||||
num_gpu_runner_capture_triggers: int = 0
|
||||
# Number of CUDAGraphs captured
|
||||
num_cudagraph_captured: int = 0
|
||||
# InductorAdaptor.compile calls
|
||||
num_inductor_compiles: int = 0
|
||||
# EagerAdapter.compile calls
|
||||
num_eager_compiles: int = 0
|
||||
# The number of times the compiler cache entry was updated
|
||||
num_cache_entries_updated: int = 0
|
||||
# The number of standalone_compile compiled artifacts saved
|
||||
num_compiled_artifacts_saved: int = 0
|
||||
# Number of times a model was loaded with CompilationLevel.DYNAMO_AS_IS
|
||||
dynamo_as_is_count: int = 0
|
||||
|
||||
def clone(self) -> "CompilationCounter":
|
||||
return copy.deepcopy(self)
|
||||
|
||||
@contextmanager
|
||||
def expect(self, **kwargs):
|
||||
old = self.clone()
|
||||
yield
|
||||
for k, v in kwargs.items():
|
||||
assert getattr(self, k) - getattr(old, k) == v, (
|
||||
f"{k} not as expected, before it is {getattr(old, k)}"
|
||||
f", after it is {getattr(self, k)}, "
|
||||
f"expected diff is {v}"
|
||||
)
|
||||
|
||||
|
||||
compilation_counter = CompilationCounter()
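The expect() helper is intended for tests; a hypothetical use (run_compilation is a placeholder for the code under test):

# Hypothetical test: expect one Dynamo graph and two backend compilations.
with compilation_counter.expect(num_graphs_seen=1, num_backend_compilations=2):
    run_compilation()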
|
||||
python/sglang/srt/model_executor/compilation/compile.py (new file, 210 lines)
@@ -0,0 +1,210 @@
|
||||
import contextvars
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_COMPILE_ENABLED = contextvars.ContextVar("_COMPILE_ENABLED", default=False)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_compiled(enabled: bool = True):
|
||||
token = _COMPILE_ENABLED.set(enabled)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_COMPILE_ENABLED.reset(token)
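A sketch of the intended toggle; the trampoline installed by install_torch_compiled (below) checks this flag on every forward call (model and its arguments are placeholders):

with set_compiled(True):
    out = model.forward(*args, **kwargs)      # compiled / piecewise path
out_eager = model.forward(*args, **kwargs)    # original eager forward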
|
||||
|
||||
|
||||
@dataclass
|
||||
class IntermediateTensors:
|
||||
"""For all pipeline stages except the last, we need to return the hidden
|
||||
states and residuals to be sent to the next stage. This data structure
|
||||
contains the hidden states and residuals for a request.
|
||||
|
||||
Each stage also needs to handle its own finished_sending and
|
||||
finished_recving in case of kv transfer.
|
||||
"""
|
||||
|
||||
tensors: dict[str, torch.Tensor]
|
||||
# [req_ids]
|
||||
finished_sending: Optional[set[str]] = None
|
||||
finished_recving: Optional[set[str]] = None
|
||||
|
||||
def __init__(self, tensors):
|
||||
# manually define this function, so that
|
||||
# Dynamo knows `IntermediateTensors()` comes from this file.
|
||||
# Otherwise, dataclass will generate this function by evaluating
|
||||
# a string, and we will lose the information about the source file.
|
||||
self.tensors = tensors
|
||||
|
||||
def __getitem__(self, key: Union[str, slice]):
|
||||
if isinstance(key, str):
|
||||
return self.tensors[key]
|
||||
elif isinstance(key, slice):
|
||||
return self.__class__({k: v[key] for k, v in self.tensors.items()})
|
||||
|
||||
def __setitem__(self, key: str, value: torch.Tensor):
|
||||
self.tensors[key] = value
|
||||
|
||||
def items(self):
|
||||
return self.tensors.items()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.tensors)
|
||||
|
||||
def __eq__(self, other: object):
|
||||
return isinstance(other, self.__class__) and self.tensors == other.tensors
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"IntermediateTensors(tensors={self.tensors})"
|
||||
|
||||
|
||||
def _normalize_dims(dims, ndim: int):
|
||||
dims = [dims] if isinstance(dims, int) else list(dims)
|
||||
return [d if d >= 0 else ndim + d for d in dims]
|
||||
|
||||
|
||||
class _MaybeIntermediateTensors:
|
||||
"""Duck-typed check to support your IntermediateTensors without importing."""
|
||||
|
||||
def __init__(self, obj):
|
||||
self.is_intermediate = hasattr(obj, "tensors") and isinstance(
|
||||
getattr(obj, "tensors"), dict
|
||||
)
|
||||
self.obj = obj
|
||||
|
||||
|
||||
def _mark_dynamic_on_value(val, dims):
|
||||
if isinstance(val, torch.Tensor):
|
||||
torch._dynamo.mark_dynamic(val, _normalize_dims(dims, val.ndim))
|
||||
else:
|
||||
mit = _MaybeIntermediateTensors(val)
|
||||
if mit.is_intermediate:
|
||||
for t in mit.obj.tensors.values():
|
||||
torch._dynamo.mark_dynamic(t, _normalize_dims(dims, t.ndim))
|
||||
# else: ignore (None or non-tensor)
|
||||
|
||||
|
||||
def _infer_dynamic_arg_dims_from_annotations(forward_fn):
|
||||
sig = inspect.signature(forward_fn)
|
||||
dyn = {}
|
||||
for name, p in sig.parameters.items():
|
||||
ann = p.annotation
|
||||
# Accept torch.Tensor / Optional[torch.Tensor] / your IntermediateTensors types by name
|
||||
if (
|
||||
ann is torch.Tensor
|
||||
or getattr(getattr(ann, "__args__", [None])[0], "__name__", "") == "Tensor"
|
||||
):
|
||||
dyn[name] = 0
|
||||
elif getattr(ann, "__name__", "") in ("IntermediateTensors",) or any(
|
||||
getattr(a, "__name__", "") == "IntermediateTensors"
|
||||
for a in getattr(ann, "__args__", [])
|
||||
):
|
||||
dyn[name] = 0
|
||||
if not dyn:
|
||||
raise ValueError("No dynamic dims inferred; pass dynamic_arg_dims explicitly.")
|
||||
return dyn
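For example, a forward annotated like the hypothetical signature below would yield {"input_ids": 0, "positions": 0}, i.e. dimension 0 (the token dimension) is marked dynamic for both arguments:

# Hypothetical signature used only for illustration.
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: ...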
|
||||
|
||||
|
||||
def install_torch_compiled(
|
||||
module: torch.nn.Module,
|
||||
*,
|
||||
dynamic_arg_dims: dict[str, Union[int, list[int]]] | None = None,
|
||||
backend_factory: Optional[Callable[[torch.fx.GraphModule, list], Callable]] = None,
|
||||
compile_config: Optional[CompilationConfig] = None,
|
||||
fullgraph: bool = True,
|
||||
graph_pool: Any = None,
|
||||
):
|
||||
unbound_fwd = module.__class__.forward
|
||||
if not callable(unbound_fwd):
|
||||
raise TypeError("module.__class__.forward must be callable")
|
||||
original_code = unbound_fwd.__code__
|
||||
|
||||
dyn_map = dynamic_arg_dims or _infer_dynamic_arg_dims_from_annotations(unbound_fwd)
|
||||
|
||||
if backend_factory is None:
|
||||
from sglang.srt.model_executor.compilation.backend import SGLangBackend
|
||||
|
||||
backend_factory = lambda gm, ex: SGLangBackend(compile_config, graph_pool)(
|
||||
gm, ex
|
||||
)
|
||||
|
||||
compiled_codes: list[type(original_code)] = []
|
||||
state = {"compiled": False, "compiled_callable": None}
|
||||
|
||||
def bytecode_hook(old_code, new_code):
|
||||
if old_code is not original_code:
|
||||
return
|
||||
frame = sys._getframe()
|
||||
while frame and frame.f_back:
|
||||
frame = frame.f_back
|
||||
if (
|
||||
frame.f_code.co_name == "_compile"
|
||||
and os.path.basename(frame.f_code.co_filename) == "convert_frame.py"
|
||||
):
|
||||
break
|
||||
try:
|
||||
dynamo_frame = frame.f_locals["frame"]
|
||||
except Exception:
|
||||
return
|
||||
if dynamo_frame.f_code is not old_code:
|
||||
return
|
||||
if dynamo_frame.f_locals.get("self") is not module:
|
||||
return
|
||||
compiled_codes.append(new_code)
|
||||
|
||||
torch._dynamo.convert_frame.register_bytecode_hook(bytecode_hook)
|
||||
|
||||
def _ensure_compiled(self, *args, **kwargs):
|
||||
"""Compile on first use (with flag ON)."""
|
||||
if state["compiled"]:
|
||||
return
|
||||
# Mark dynamic dims only when we are about to compile
|
||||
sig = inspect.signature(unbound_fwd)
|
||||
ba = sig.bind(self, *args, **kwargs)
|
||||
ba.apply_defaults()
|
||||
for name, dims in (dyn_map or {}).items():
|
||||
if name in ba.arguments:
|
||||
val = ba.arguments[name]
|
||||
if val is not None:
|
||||
_mark_dynamic_on_value(val, dims)
|
||||
|
||||
# Avoid cross-instance cache reuse
|
||||
torch._dynamo.eval_frame.remove_from_cache(unbound_fwd.__code__)
|
||||
|
||||
bound = types.MethodType(unbound_fwd, self)
|
||||
compiled_callable = torch.compile(
|
||||
bound, fullgraph=fullgraph, backend=backend_factory
|
||||
)
|
||||
|
||||
# Trigger Dynamo so bytecode hook can capture
|
||||
compiled_callable(*args, **kwargs)
|
||||
|
||||
state["compiled"] = True
|
||||
state["compiled_callable"] = compiled_callable
|
||||
|
||||
def trampoline(self, *args, **kwargs):
|
||||
use_compiled = _COMPILE_ENABLED.get()
|
||||
if use_compiled:
|
||||
if not state["compiled"]:
|
||||
_ensure_compiled(self, *args, **kwargs)
|
||||
|
||||
compiled_callable = state["compiled_callable"]
|
||||
return compiled_callable(*args, **kwargs)
|
||||
else:
|
||||
# Explicitly run the original uncompiled forward
|
||||
return unbound_fwd(self, *args, **kwargs)
|
||||
|
||||
module.forward = types.MethodType(trampoline, module)
|
||||
return module
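Putting the pieces together, a minimal sketch (assumed wiring; the forward signature and variable names are placeholders) of how a model runner would use this module:

# Hypothetical end-to-end wiring of the piecewise torch.compile path.
cfg = CompilationConfig(capture_sizes=[1, 2, 4, 8])
pool = torch.cuda.graph_pool_handle()
install_torch_compiled(model, compile_config=cfg, graph_pool=pool)

with set_compiled(True):
    # The first call triggers Dynamo tracing plus piecewise Inductor compilation;
    # later calls whose shape is in capture_sizes replay captured CUDA graphs.
    hidden_states = model.forward(input_ids, positions, forward_batch)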
|
||||
@@ -0,0 +1,479 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compiler_interface.py
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import hashlib
|
||||
import os
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch._inductor.compile_fx
|
||||
import torch.fx as fx
|
||||
|
||||
from sglang.srt.model_executor.compilation.compilation_counter import (
|
||||
compilation_counter,
|
||||
)
|
||||
from sglang.srt.model_executor.compilation.inductor_pass import pass_context
|
||||
|
||||
|
||||
class CompilerInterface:
|
||||
"""
|
||||
The interface for a compiler that can be used by vLLM.
|
||||
"""
|
||||
|
||||
# The name of the compiler, e.g. inductor.
|
||||
# This is a class-level attribute.
|
||||
name: str
|
||||
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
):
|
||||
"""
|
||||
when the vLLM process uses `cache_dir` as the cache directory,
|
||||
the compiler should initialize itself with the cache directory,
|
||||
e.g. by re-directing its own cache directory to a sub-directory.
|
||||
|
||||
prefix can be used in combination with cache_dir to figure out the base
|
||||
cache directory, e.g. when there are multiple parts of the model being compiled,
|
||||
but we want to share the same cache directory for all of them.
|
||||
|
||||
e.g.
|
||||
cache_dir = "/path/to/dir/backbone", prefix = "backbone"
|
||||
cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"
|
||||
"""
|
||||
pass
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
Gather all the relevant information from the vLLM config,
|
||||
to compute a hash so that we can cache the compiled model.
|
||||
|
||||
See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash]
|
||||
to check what information
|
||||
is already considered by default. This function should only
|
||||
consider the information that is specific to the compiler.
|
||||
"""
|
||||
return ""
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: Optional[int] = None,
|
||||
key: Optional[str] = None,
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
"""
|
||||
Compile the graph with the given example inputs and compiler config,
|
||||
with a runtime shape. If the `runtime_shape` is None, it means
|
||||
the `example_inputs` have a dynamic shape. Otherwise, the
|
||||
`runtime_shape` specifies the shape of the inputs. Right now we only
|
||||
support one variable shape for all inputs, which is the batchsize
|
||||
(number of tokens) during inference.
|
||||
|
||||
Dynamo will make sure `graph(*example_inputs)` is valid.
|
||||
|
||||
The function should return a compiled callable function, as well as
|
||||
a handle that can be used to directly load the compiled function.
|
||||
|
||||
The handle should be a plain Python object, preferably a string or a
|
||||
file path for readability.
|
||||
|
||||
If the compiler doesn't support caching, it should return None for the
|
||||
handle. If the compiler fails to compile the graph, it should return
|
||||
None for the compiled function as well.
|
||||
|
||||
`key` is required for StandaloneInductorAdaptor; it specifies where to
|
||||
save the compiled artifact. The compiled artifact gets saved to
|
||||
`cache_dir/key`.
|
||||
"""
|
||||
return None, None
|
||||
|
||||
def load(
|
||||
self,
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None,
|
||||
) -> Callable:
|
||||
"""
|
||||
Load the compiled function from the handle.
|
||||
Raises an error if the handle is invalid.
|
||||
|
||||
The handle is the second return value of the `compile` function.
|
||||
"""
|
||||
raise NotImplementedError("caching is not supported")
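For contrast, a hypothetical minimal adaptor implementing this interface (the compilation counter above already reserves a field for eager compiles); it performs no real compilation and returns no cache handle:

class EagerAdaptor(CompilerInterface):
    # Hypothetical sketch, not part of this commit.
    name = "eager"

    def compile(self, graph, example_inputs, compiler_config,
                runtime_shape=None, key=None):
        compilation_counter.num_eager_compiles += 1
        return graph, None  # the GraphModule itself serves as the "compiled" callable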
|
||||
|
||||
|
||||
def get_inductor_factors() -> list[Any]:
|
||||
factors: list[Any] = []
|
||||
# summarize system state
|
||||
from torch._inductor.codecache import CacheBase
|
||||
|
||||
system_factors = CacheBase.get_system()
|
||||
factors.append(system_factors)
|
||||
|
||||
# summarize pytorch state
|
||||
from torch._inductor.codecache import torch_key
|
||||
|
||||
torch_factors = torch_key()
|
||||
factors.append(torch_factors)
|
||||
return factors
|
||||
|
||||
|
||||
class AlwaysHitShapeEnv:
|
||||
"""
|
||||
Why do we need this class:
|
||||
|
||||
For normal `torch.compile` usage, every compilation will have
|
||||
one Dynamo bytecode compilation and one Inductor compilation.
|
||||
The Inductor compilation happens under the context of the
|
||||
Dynamo bytecode compilation, and that context is used to
|
||||
determine the dynamic shape information, etc.
|
||||
|
||||
For our use case, we only run Dynamo bytecode compilation once,
|
||||
and run Inductor compilation multiple times with different shapes
|
||||
plus a general shape. The compilation for specific shapes happens
|
||||
outside of the context of the Dynamo bytecode compilation. At that
|
||||
time, we don't have a shape environment to provide to Inductor, and
|
||||
it will fail the Inductor code cache lookup.
|
||||
|
||||
By providing a dummy shape environment that always hits, we can
|
||||
make the Inductor code cache lookup always hit, and we can
|
||||
compile the graph for different shapes as needed.
|
||||
|
||||
The following dummy methods are obtained by trial-and-error
|
||||
until it works.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.guards: list[Any] = []
|
||||
|
||||
def evaluate_guards_expression(self, *args, **kwargs):
|
||||
return True
|
||||
|
||||
def get_pruned_guards(self, *args, **kwargs):
|
||||
return []
|
||||
|
||||
def produce_guards_expression(self, *args, **kwargs):
|
||||
return ""
|
||||
|
||||
|
||||
class InductorAdaptor(CompilerInterface):
|
||||
"""
|
||||
The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.
|
||||
"""
|
||||
|
||||
name = "inductor"
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
factors = get_inductor_factors()
|
||||
hash_str = hashlib.md5(
|
||||
str(factors).encode(), usedforsecurity=False
|
||||
).hexdigest()[:10]
|
||||
return hash_str
|
||||
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
):
|
||||
self.cache_dir = cache_dir
|
||||
self.prefix = prefix
|
||||
self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
|
||||
if disable_cache:
|
||||
return
|
||||
# redirect the cache directory to a sub-directory
|
||||
# set flags so that Inductor and Triton store their cache
|
||||
# in the cache_dir, then users only need to copy the cache_dir
|
||||
# to another machine to reuse the cache.
|
||||
inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
|
||||
os.makedirs(inductor_cache, exist_ok=True)
|
||||
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
|
||||
triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
|
||||
os.makedirs(triton_cache, exist_ok=True)
|
||||
os.environ["TRITON_CACHE_DIR"] = triton_cache
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: Optional[int] = None,
|
||||
key: Optional[str] = None,
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
|
||||
current_config = {}
|
||||
if compiler_config is not None:
|
||||
current_config.update(compiler_config)
|
||||
|
||||
# disable remote cache
|
||||
current_config["fx_graph_cache"] = True
|
||||
current_config["fx_graph_remote_cache"] = False
|
||||
|
||||
set_inductor_config(current_config, runtime_shape)
|
||||
|
||||
# inductor can inplace modify the graph, so we need to copy it
|
||||
# see https://github.com/pytorch/pytorch/issues/138980
|
||||
graph = copy.deepcopy(graph)
|
||||
|
||||
# it's the first time we compile this graph
|
||||
# the assumption is that we don't have nested Inductor compilation.
|
||||
# compiled_fx_graph_hash will only be called once, and we can hook
|
||||
# it to get the hash of the compiled graph directly.
|
||||
|
||||
hash_str, file_path = None, None
|
||||
from torch._inductor.codecache import FxGraphCache, compiled_fx_graph_hash
|
||||
|
||||
if torch.__version__.startswith("2.5"):
|
||||
original_load = FxGraphCache.load
|
||||
original_load_name = "torch._inductor.codecache.FxGraphCache.load"
|
||||
|
||||
def hijack_load(*args, **kwargs):
|
||||
inductor_compiled_graph = original_load(*args, **kwargs)
|
||||
nonlocal file_path
|
||||
compiled_fn = inductor_compiled_graph.current_callable
|
||||
file_path = compiled_fn.__code__.co_filename # noqa
|
||||
if not file_path.startswith(self.base_cache_dir):
|
||||
# hooked in the align_inputs_from_check_idxs function
|
||||
# in torch/_inductor/utils.py
|
||||
for cell in compiled_fn.__closure__:
|
||||
if not callable(cell.cell_contents):
|
||||
continue
|
||||
if cell.cell_contents.__code__.co_filename.startswith(
|
||||
self.base_cache_dir
|
||||
):
|
||||
# this is the real file path compiled from Inductor
|
||||
file_path = cell.cell_contents.__code__.co_filename
|
||||
break
|
||||
return inductor_compiled_graph
|
||||
|
||||
hijacked_compile_fx_inner = (
|
||||
torch._inductor.compile_fx.compile_fx_inner
|
||||
) # noqa
|
||||
elif torch.__version__ >= "2.6":
|
||||
# function renamed in 2.6
|
||||
original_load_name = None
|
||||
|
||||
def hijacked_compile_fx_inner(*args, **kwargs):
|
||||
output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
|
||||
nonlocal hash_str
|
||||
inductor_compiled_graph = output
|
||||
if inductor_compiled_graph is not None:
|
||||
nonlocal file_path
|
||||
compiled_fn = inductor_compiled_graph.current_callable
|
||||
file_path = compiled_fn.__code__.co_filename # noqa
|
||||
if not file_path.startswith(self.base_cache_dir):
|
||||
# hooked in the align_inputs_from_check_idxs function
|
||||
# in torch/_inductor/utils.py
|
||||
for cell in compiled_fn.__closure__:
|
||||
if not callable(cell.cell_contents):
|
||||
continue
|
||||
code = cell.cell_contents.__code__
|
||||
if code.co_filename.startswith(self.base_cache_dir):
|
||||
# this is the real file path
|
||||
# compiled from Inductor
|
||||
file_path = code.co_filename
|
||||
break
|
||||
hash_str = inductor_compiled_graph._fx_graph_cache_key
|
||||
return output
|
||||
|
||||
def hijack_compiled_fx_graph_hash(*args, **kwargs):
|
||||
out = compiled_fx_graph_hash(*args, **kwargs)
|
||||
nonlocal hash_str
|
||||
hash_str = out[0]
|
||||
return out
|
||||
|
||||
def _check_can_cache(*args, **kwargs):
|
||||
# no error means it can be cached.
|
||||
# Inductor refuses to cache the graph outside of Dynamo
|
||||
# tracing context, and also disables caching for graphs
|
||||
# with high-order ops.
|
||||
# For vLLM, in either case, we want to cache the graph.
|
||||
# see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
|
||||
return
|
||||
|
||||
def _get_shape_env() -> AlwaysHitShapeEnv:
|
||||
return AlwaysHitShapeEnv()
|
||||
|
||||
with ExitStack() as stack:
|
||||
# hijack to get the compiled graph itself
|
||||
if original_load_name is not None:
|
||||
stack.enter_context(patch(original_load_name, hijack_load))
|
||||
|
||||
# for hijacking the hash of the compiled graph
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"torch._inductor.codecache.compiled_fx_graph_hash",
|
||||
hijack_compiled_fx_graph_hash,
|
||||
)
|
||||
)
|
||||
|
||||
# for providing a dummy shape environment
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"torch._inductor.codecache.FxGraphCache._get_shape_env",
|
||||
_get_shape_env,
|
||||
)
|
||||
)
|
||||
|
||||
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
|
||||
|
||||
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
|
||||
if hasattr(AOTAutogradCache, "_get_shape_env"):
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
|
||||
_get_shape_env,
|
||||
)
|
||||
)
|
||||
|
||||
# for forcing the graph to be cached
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"torch._inductor.codecache.FxGraphCache._check_can_cache",
|
||||
_check_can_cache,
|
||||
)
|
||||
)
|
||||
|
||||
# Dynamo metrics context, see method for more details.
|
||||
stack.enter_context(self.metrics_context())
|
||||
|
||||
# Disable remote caching. When these are on, on remote cache-hit,
|
||||
# the monkey-patched functions never actually get called.
|
||||
# vLLM today assumes and requires the monkey-patched functions to
|
||||
# get hit.
|
||||
# TODO(zou3519): we're going to replace this all with
|
||||
# standalone_compile sometime.
|
||||
|
||||
stack.enter_context(
|
||||
torch._inductor.config.patch(fx_graph_remote_cache=False)
|
||||
)
|
||||
# InductorAdaptor (unfortunately) requires AOTAutogradCache
|
||||
# to be turned off to run. It will fail to acquire the hash_str
|
||||
# and error if not.
|
||||
# StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
|
||||
stack.enter_context(
|
||||
torch._functorch.config.patch(enable_autograd_cache=False)
|
||||
)
|
||||
stack.enter_context(
|
||||
torch._functorch.config.patch(enable_remote_autograd_cache=False)
|
||||
)
|
||||
|
||||
with pass_context(runtime_shape):
|
||||
compiled_graph = compile_fx(
|
||||
graph,
|
||||
example_inputs,
|
||||
inner_compile=hijacked_compile_fx_inner,
|
||||
config_patches=current_config,
|
||||
)
|
||||
return compiled_graph, (hash_str, file_path)
|
||||
|
||||
def load(
|
||||
self,
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None,
|
||||
) -> Callable:
|
||||
assert isinstance(handle, tuple)
|
||||
assert isinstance(handle[0], str)
|
||||
assert isinstance(handle[1], str)
|
||||
hash_str = handle[0]
|
||||
|
||||
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
|
||||
from torch._inductor.codecache import FxGraphCache
|
||||
|
||||
with ExitStack() as exit_stack:
|
||||
exit_stack.enter_context(
|
||||
patch(
|
||||
"torch._inductor.codecache.FxGraphCache._get_shape_env",
|
||||
lambda *args, **kwargs: AlwaysHitShapeEnv(),
|
||||
)
|
||||
)
|
||||
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
|
||||
if hasattr(AOTAutogradCache, "_get_shape_env"):
|
||||
exit_stack.enter_context(
|
||||
patch(
|
||||
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
|
||||
lambda *args, **kwargs: AlwaysHitShapeEnv(),
|
||||
)
|
||||
)
|
||||
|
||||
# Dynamo metrics context, see method for more details.
|
||||
exit_stack.enter_context(self.metrics_context())
|
||||
|
||||
if torch.__version__.startswith("2.5"):
|
||||
inductor_compiled_graph = FxGraphCache._lookup_graph(
|
||||
hash_str, example_inputs, True, False
|
||||
)
|
||||
assert inductor_compiled_graph is not None, (
|
||||
"Inductor cache lookup failed. Please remove"
|
||||
f"the cache directory and try again." # noqa
|
||||
)
|
||||
elif torch.__version__ >= "2.6":
|
||||
from torch._inductor.output_code import CompiledFxGraphConstantsWithGm
|
||||
|
||||
constants = CompiledFxGraphConstantsWithGm(graph)
|
||||
inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
|
||||
hash_str, example_inputs, True, None, constants
|
||||
)
|
||||
assert inductor_compiled_graph is not None, (
|
||||
"Inductor cache lookup failed. Please remove"
|
||||
f"the cache directory and try again." # noqa
|
||||
)
|
||||
|
||||
# Inductor calling convention (function signature):
|
||||
# f(list) -> tuple
|
||||
# Dynamo calling convention (function signature):
|
||||
# f(*args) -> Any
|
||||
|
||||
# need to know if the graph returns a tuple
|
||||
from torch._inductor.compile_fx import graph_returns_tuple
|
||||
|
||||
returns_tuple = graph_returns_tuple(graph)
|
||||
|
||||
# this is the callable we return to Dynamo to run
|
||||
def compiled_graph(*args):
|
||||
# convert args to list
|
||||
list_args = list(args)
|
||||
graph_output = inductor_compiled_graph(list_args)
|
||||
# unpack the tuple if needed
|
||||
if returns_tuple:
|
||||
return graph_output
|
||||
else:
|
||||
return graph_output[0]
|
||||
|
||||
return compiled_graph
|
||||
|
||||
def metrics_context(self) -> contextlib.AbstractContextManager:
|
||||
"""
|
||||
This method returns the Dynamo metrics context (if it exists,
|
||||
otherwise a null context). It is used by various compile components.
|
||||
Present in torch>=2.6, it's used inside FxGraphCache in
|
||||
torch==2.6 (but not after). It might also be used in various other
|
||||
torch.compile internal functions.
|
||||
|
||||
Because it is re-entrant, we always set it (even if entering via Dynamo
|
||||
and the context was already entered). We might want to revisit if it
|
||||
should be set at a different level of compilation.
|
||||
|
||||
This is likely a bug in PyTorch: public APIs should not rely on
|
||||
manually setting up internal contexts. But we also rely on non-public
|
||||
APIs which might not provide these guarantees.
|
||||
"""
|
||||
import torch._dynamo.utils
|
||||
|
||||
return torch._dynamo.utils.get_metrics_context()
|
||||
|
||||
|
||||
def set_inductor_config(config, runtime_shape):
|
||||
if isinstance(runtime_shape, int):
|
||||
# for a specific batchsize, tuning triton kernel parameters
|
||||
# can be beneficial
|
||||
config["max_autotune"] = True
|
||||
config["coordinate_descent_tuning"] = True
|
||||
@@ -0,0 +1,230 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/cuda_piecewise_backend.py
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
|
||||
import sglang.srt.model_executor.compilation.weak_ref_tensor_jit
|
||||
from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig
|
||||
from sglang.srt.model_executor.compilation.compilation_counter import (
|
||||
compilation_counter,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def weak_ref_tensor(tensor: Any) -> Any:
|
||||
"""
|
||||
Create a weak reference to a tensor.
|
||||
The new tensor will share the same data as the original tensor,
|
||||
but will not keep the original tensor alive.
|
||||
"""
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
# TODO(yuwei): introduce weak_ref_tensor from sgl_kernel
|
||||
return torch.ops.jit_weak_ref_tensor.weak_ref_tensor(tensor)
|
||||
return tensor
|
||||
|
||||
|
||||
def weak_ref_tensors(
|
||||
tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
|
||||
) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
|
||||
"""
|
||||
Convenience function to create weak references to tensors,
|
||||
for single tensor, list of tensors or tuple of tensors.
|
||||
"""
|
||||
if isinstance(tensors, torch.Tensor):
|
||||
return weak_ref_tensor(tensors)
|
||||
if isinstance(tensors, list):
|
||||
return [weak_ref_tensor(t) for t in tensors]
|
||||
if isinstance(tensors, tuple):
|
||||
return tuple(weak_ref_tensor(t) for t in tensors)
|
||||
raise ValueError("Invalid type for tensors")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ConcreteSizeEntry:
|
||||
runtime_shape: int
|
||||
need_to_compile: bool # the size is in compile_sizes
|
||||
use_cudagraph: bool # the size is in cudagraph_capture_sizes
|
||||
|
||||
compiled: bool = False
|
||||
runnable: Callable = None # type: ignore
|
||||
num_finished_warmup: int = 0
|
||||
cudagraph: Optional[torch.cuda.CUDAGraph] = None
|
||||
output: Optional[Any] = None
|
||||
|
||||
# for cudagraph debugging, track the input addresses
|
||||
# during capture, and check if they are the same during replay
|
||||
input_addresses: Optional[list[int]] = None
|
||||
|
||||
|
||||
class CUDAPiecewiseBackend:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
compile_config: CompilationConfig,
|
||||
inductor_config: dict[str, Any],
|
||||
graph_pool: Any,
|
||||
piecewise_compile_index: int,
|
||||
total_piecewise_compiles: int,
|
||||
sym_shape_indices: list[int],
|
||||
compiled_graph_for_general_shape: Callable,
|
||||
sglang_backend,
|
||||
):
|
||||
"""
|
||||
The backend for piecewise compilation.
|
||||
It mainly handles the compilation and cudagraph capturing.
|
||||
|
||||
We will compile `self.graph` once for the general shape,
|
||||
and then compile for different shapes specified in
|
||||
`compilation_config.compile_sizes`.
|
||||
|
||||
Independently, we will capture cudagraph for different shapes.
|
||||
|
||||
If a shape needs both compilation and cudagraph, we will
|
||||
compile it first, and then capture cudagraph.
|
||||
"""
|
||||
self.graph = graph
|
||||
self.inductor_config = inductor_config
|
||||
self.graph_pool = graph_pool
|
||||
self.piecewise_compile_index = piecewise_compile_index
|
||||
self.total_piecewise_compiles = total_piecewise_compiles
|
||||
self.sglang_backend = sglang_backend
|
||||
|
||||
self.is_first_graph = piecewise_compile_index == 0
|
||||
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
|
||||
|
||||
self.compile_sizes: set[int] = set([])
|
||||
self.compile_config = compile_config
|
||||
self.cudagraph_capture_sizes: set[int] = set(compile_config.get_capture_sizes())
|
||||
|
||||
self.first_run_finished = False
|
||||
|
||||
self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa
|
||||
|
||||
self.sym_shape_indices = sym_shape_indices
|
||||
|
||||
self.is_debugging_mode = True
|
||||
|
||||
# the entries for different shapes that we need to either
|
||||
# compile or capture cudagraph
|
||||
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
|
||||
|
||||
# to_be_compiled_sizes tracks the remaining sizes to compile,
|
||||
# and updates during the compilation process, so we need to copy it
|
||||
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
|
||||
for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
|
||||
self.concrete_size_entries[shape] = ConcreteSizeEntry(
|
||||
runtime_shape=shape,
|
||||
need_to_compile=shape in self.compile_sizes,
|
||||
use_cudagraph=shape in self.cudagraph_capture_sizes,
|
||||
)
|
||||
|
||||
def check_for_ending_compilation(self):
|
||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||
# no specific sizes to compile
|
||||
# save the hash of the inductor graph for the next run
|
||||
self.sglang_backend.compiler_manager.save_to_file()
|
||||
|
||||
def __call__(self, *args) -> Any:
|
||||
if not self.first_run_finished:
|
||||
self.first_run_finished = True
|
||||
self.check_for_ending_compilation()
|
||||
return self.compiled_graph_for_general_shape(*args)
|
||||
runtime_shape = args[self.sym_shape_indices[0]]
|
||||
if runtime_shape not in self.concrete_size_entries:
|
||||
# we don't need to do anything for this shape
|
||||
return self.compiled_graph_for_general_shape(*args)
|
||||
|
||||
entry = self.concrete_size_entries[runtime_shape]
|
||||
|
||||
if entry.runnable is None:
|
||||
entry.runnable = self.compiled_graph_for_general_shape
|
||||
|
||||
if entry.need_to_compile and not entry.compiled:
|
||||
entry.compiled = True
|
||||
self.to_be_compiled_sizes.remove(runtime_shape)
|
||||
# args are real arguments
|
||||
entry.runnable = self.sglang_backend.compiler_manager.compile(
|
||||
self.graph,
|
||||
args,
|
||||
self.inductor_config,
|
||||
graph_index=self.piecewise_compile_index,
|
||||
num_graphs=self.total_piecewise_compiles,
|
||||
runtime_shape=runtime_shape,
|
||||
)
|
||||
|
||||
# finished compilations for all required shapes
|
||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||
self.check_for_ending_compilation()
|
||||
|
||||
# Skip CUDA graphs if this entry doesn't use them OR
|
||||
# if we're supposed to skip them globally
|
||||
# skip_cuda_graphs = get_forward_context().skip_cuda_graphs
|
||||
# if not entry.use_cudagraph or skip_cuda_graphs:
|
||||
# return entry.runnable(*args)
|
||||
|
||||
if entry.cudagraph is None:
|
||||
if entry.num_finished_warmup < 1: # noqa
|
||||
entry.num_finished_warmup += 1
|
||||
return entry.runnable(*args)
|
||||
|
||||
input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
entry.input_addresses = input_addresses
|
||||
cudagraph = torch.cuda.CUDAGraph()
|
||||
|
||||
with ExitStack() as stack:
|
||||
if not self.is_first_graph:
|
||||
# during every model forward, we will capture
|
||||
# many pieces of cudagraphs (roughly one per layer).
|
||||
# running gc again and again across layers will
|
||||
# make the cudagraph capture very slow.
|
||||
# therefore, we only run gc for the first graph,
|
||||
# and disable gc for the rest of the graphs.
|
||||
stack.enter_context(patch("gc.collect", lambda: None))
|
||||
stack.enter_context(patch("torch.cuda.empty_cache", lambda: None))
|
||||
|
||||
# mind-exploding: carefully manage the reference and memory.
|
||||
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
|
||||
# `output` is managed by pytorch's cudagraph pool
|
||||
output = entry.runnable(*args)
|
||||
if self.is_last_graph:
|
||||
# by converting it to weak ref,
|
||||
# the original `output` will immediately be released
|
||||
# to save memory. It is only safe to do this for
|
||||
# the last graph, because the output of the last graph
|
||||
# will not be used by any other cuda graph.
|
||||
output = weak_ref_tensors(output)
|
||||
|
||||
# here we always use weak ref for the output
|
||||
# to save memory
|
||||
entry.output = weak_ref_tensors(output)
|
||||
entry.cudagraph = cudagraph
|
||||
|
||||
compilation_counter.num_cudagraph_captured += 1
|
||||
|
||||
# important: we need to return the output, rather than
|
||||
# the weak ref of the output, so that pytorch can correctly
|
||||
# manage the memory during cuda graph capture
|
||||
return output
|
||||
|
||||
if self.is_debugging_mode:
|
||||
# check if the input addresses are the same
|
||||
new_input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
assert new_input_addresses == entry.input_addresses, (
|
||||
"Input addresses for cudagraphs are different during replay."
|
||||
f" Expected {entry.input_addresses}, got {new_input_addresses}"
|
||||
)
|
||||
|
||||
entry.cudagraph.replay()
|
||||
return entry.output
|
||||
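For reference, a minimal standalone sketch of the per-shape dispatch implemented above (the class and entry names are illustrative, not the actual SGLang types): each captured size gets its own entry, and any other runtime shape falls back to the graph compiled for the general (dynamic) shape.

from dataclasses import dataclass
from typing import Any, Callable, Optional

@dataclass
class _Entry:
    runtime_shape: int
    runnable: Optional[Callable] = None  # shape-specialized compile goes here

class _ShapeDispatcher:
    def __init__(self, capture_sizes, general_fn):
        self.entries = {s: _Entry(s) for s in capture_sizes}
        self.general_fn = general_fn  # compiled graph for the general shape

    def __call__(self, shape: int, *args: Any):
        entry = self.entries.get(shape)
        if entry is None:
            # Shape was not selected for specialization: use the general graph.
            return self.general_fn(*args)
        if entry.runnable is None:
            # Placeholder until the shape-specialized compilation finishes.
            entry.runnable = self.general_fn
        return entry.runnable(*args)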
@@ -0,0 +1,134 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/fix_functionalization.py
|
||||
|
||||
import logging
|
||||
import operator
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
|
||||
from sglang.srt.model_executor.compilation.fx_utils import is_func
|
||||
from sglang.srt.model_executor.compilation.inductor_pass import SGLangInductorPass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FixFunctionalizationPass(SGLangInductorPass):
|
||||
"""
|
||||
This pass defunctionalizes certain nodes to avoid redundant tensor copies.
|
||||
After this pass, DCE (dead-code elimination) should never be run,
|
||||
as de-functionalized nodes may appear as dead code.
|
||||
|
||||
To add new nodes to defunctionalize, add to the if-elif chain in __call__.
|
||||
"""
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.begin()
|
||||
self.dump_graph(graph, "before_fix_functionalization")
|
||||
|
||||
self.nodes_to_remove: list[torch.fx.Node] = []
|
||||
count = 0
|
||||
for node in graph.nodes:
|
||||
if not is_func(node, auto_functionalized):
|
||||
continue # Avoid deep if-elif nesting
|
||||
count += 1
|
||||
|
||||
self.dump_graph(graph, "before_fix_functionalization_cleanup")
|
||||
|
||||
# Remove the nodes all at once
|
||||
count_removed = len(self.nodes_to_remove)
|
||||
for node in self.nodes_to_remove:
|
||||
graph.erase_node(node)
|
||||
|
||||
logger.debug(
|
||||
"De-functionalized %s nodes, removed %s nodes", count, count_removed
|
||||
)
|
||||
self.dump_graph(graph, "after_fix_functionalization")
|
||||
self.end_and_log()
|
||||
|
||||
def _remove(self, node_or_nodes: Union[torch.fx.Node, Iterable[torch.fx.Node]]):
|
||||
"""
|
||||
Stage a node (or nodes) for removal at the end of the pass.
|
||||
"""
|
||||
if isinstance(node_or_nodes, torch.fx.Node):
|
||||
self.nodes_to_remove.append(node_or_nodes)
|
||||
else:
|
||||
self.nodes_to_remove.extend(node_or_nodes)
|
||||
|
||||
def defunctionalize(
|
||||
self,
|
||||
graph: torch.fx.Graph,
|
||||
node: torch.fx.Node,
|
||||
mutated_args: dict[int, Union[torch.fx.Node, str]],
|
||||
args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None,
|
||||
):
|
||||
"""
|
||||
De-functionalize a node by replacing it with a call to the original.
|
||||
It also replaces the getitem users with the mutated arguments.
|
||||
See replace_users_with_mutated_args and insert_defunctionalized.
|
||||
"""
|
||||
self.replace_users_with_mutated_args(node, mutated_args)
|
||||
self.insert_defunctionalized(graph, node, args=args)
|
||||
self._remove(node)
|
||||
|
||||
def replace_users_with_mutated_args(
|
||||
self, node: torch.fx.Node, mutated_args: dict[int, Union[torch.fx.Node, str]]
|
||||
):
|
||||
"""
|
||||
Replace all getitem users of the auto-functionalized node with the
|
||||
mutated arguments.
|
||||
:param node: The auto-functionalized node
|
||||
:param mutated_args: The mutated arguments, indexed by getitem index.
|
||||
If the value of an arg is a string, `node.kwargs[arg]` is used.
|
||||
"""
|
||||
for idx, user in self.getitem_users(node).items():
|
||||
arg = mutated_args[idx]
|
||||
arg = node.kwargs[arg] if isinstance(arg, str) else arg
|
||||
user.replace_all_uses_with(arg)
|
||||
self._remove(user)
|
||||
|
||||
def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
|
||||
"""
|
||||
Returns the operator.getitem users of the auto-functionalized node,
|
||||
indexed by the index they are getting.
|
||||
"""
|
||||
users = {}
|
||||
for user in node.users:
|
||||
if is_func(user, operator.getitem):
|
||||
idx = user.args[1]
|
||||
users[idx] = user
|
||||
return users
|
||||
|
||||
def insert_defunctionalized(
|
||||
self,
|
||||
graph: torch.fx.Graph,
|
||||
node: torch.fx.Node,
|
||||
args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None,
|
||||
):
|
||||
"""
|
||||
Insert a new defunctionalized node into the graph before node.
|
||||
If one of the kwargs is 'out', provide args directly,
|
||||
as node.kwargs cannot be used.
|
||||
See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351
|
||||
|
||||
:param graph: Graph to insert the defunctionalized node into
|
||||
:param node: The auto-functionalized node to defunctionalize
|
||||
:param args: If we cannot use kwargs, specify args directly.
|
||||
If an arg is a string, `node.kwargs[arg]` is used.
|
||||
""" # noqa: E501
|
||||
assert is_func(
|
||||
node, auto_functionalized
|
||||
), f"node must be auto-functionalized, is {node} instead"
|
||||
|
||||
# Create a new call to the original function
|
||||
with graph.inserting_before(node):
|
||||
function = node.args[0]
|
||||
if args is None:
|
||||
graph.call_function(function, kwargs=node.kwargs)
|
||||
else:
|
||||
# Args passed as strings refer to items in node.kwargs
|
||||
args = tuple(
|
||||
node.kwargs[arg] if isinstance(arg, str) else arg for arg in args
|
||||
)
|
||||
graph.call_function(function, args=args)
|
||||
python/sglang/srt/model_executor/compilation/fx_utils.py (new file, 83 lines)
@@ -0,0 +1,83 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/fx_utils.py
|
||||
|
||||
import operator
|
||||
from collections.abc import Iterable, Iterator
|
||||
from typing import Optional
|
||||
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._ops import OpOverload
|
||||
|
||||
|
||||
def is_func(node: fx.Node, target) -> bool:
|
||||
return node.op == "call_function" and node.target == target
|
||||
|
||||
|
||||
def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
|
||||
return is_func(node, auto_functionalized) and node.args[0] == op
|
||||
|
||||
|
||||
# Returns the first specified node with the given op (if it exists)
|
||||
def find_specified_fn_maybe(
|
||||
nodes: Iterable[fx.Node], op: OpOverload
|
||||
) -> Optional[fx.Node]:
|
||||
for node in nodes:
|
||||
if node.target == op:
|
||||
return node
|
||||
return None
|
||||
|
||||
|
||||
# Returns the first specified node with the given op
|
||||
def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
|
||||
node = find_specified_fn_maybe(nodes, op)
|
||||
assert node is not None, f"Could not find {op} in nodes {nodes}"
|
||||
return node
|
||||
|
||||
|
||||
# Returns the first auto_functionalized node with the given op (if it exists)
|
||||
def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> Optional[fx.Node]:
|
||||
for node in nodes:
|
||||
if is_func(node, auto_functionalized) and node.args[0] == op: # noqa
|
||||
return node
|
||||
return None
|
||||
|
||||
|
||||
# Returns the first auto_functionalized node with the given op
|
||||
def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
|
||||
node = find_auto_fn_maybe(nodes, op)
|
||||
assert node is not None, f"Could not find {op} in nodes {nodes}"
|
||||
return node
|
||||
|
||||
|
||||
# Returns the getitem node that extracts the idx-th element from node
|
||||
# (if it exists)
|
||||
def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]:
|
||||
for user in node.users:
|
||||
if is_func(user, operator.getitem) and user.args[1] == idx:
|
||||
return user
|
||||
return None
|
||||
|
||||
|
||||
# Returns the getitem node that extracts the idx-th element from node
|
||||
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
|
||||
ret = find_getitem_maybe(node, idx)
|
||||
assert ret is not None, f"Could not find getitem {idx} in node {node}"
|
||||
return ret
|
||||
|
||||
|
||||
# An auto-functionalization-aware utility for finding nodes with a specific op
|
||||
def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
|
||||
if not op._schema.is_mutable:
|
||||
yield from graph.find_nodes(op="call_function", target=op)
|
||||
|
||||
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
|
||||
if n.args[0] == op:
|
||||
yield n
|
||||
|
||||
|
||||
# Asserts that the node only has one user and returns it
|
||||
# Even if a node has only 1 user, it might share storage with another node,
|
||||
# which might need to be taken into account.
|
||||
def get_only_user(node: fx.Node) -> fx.Node:
|
||||
assert len(node.users) == 1
|
||||
return next(iter(node.users))
|
||||
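A small standalone usage sketch of these helpers (the traced function is illustrative and not part of the change): is_func() checks both node.op and node.target, so it picks out the relu call_function node while skipping placeholders and the output node.

import torch
from torch import fx

def _demo(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1

_gm = fx.symbolic_trace(_demo)
relu_nodes = [n for n in _gm.graph.nodes if is_func(n, torch.relu)]
assert len(relu_nodes) == 1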
python/sglang/srt/model_executor/compilation/inductor_pass.py (new file, 140 lines)
@@ -0,0 +1,140 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/inductor_pass.py
|
||||
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
from torch._dynamo.utils import lazy_format_graph_code
|
||||
from torch._inductor.custom_graph_pass import CustomGraphPass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_pass_context = None
|
||||
|
||||
|
||||
class PassContext:
|
||||
|
||||
def __init__(self, runtime_shape: Optional[int]):
|
||||
self.runtime_shape = runtime_shape
|
||||
|
||||
|
||||
def get_pass_context() -> PassContext:
|
||||
"""Get the current pass context."""
|
||||
assert _pass_context is not None
|
||||
return _pass_context
|
||||
|
||||
|
||||
@contextmanager
|
||||
def pass_context(runtime_shape: Optional[int]):
|
||||
"""A context manager that stores the current pass context,
|
||||
usually it is a list of sizes to specialize.
|
||||
"""
|
||||
global _pass_context
|
||||
prev_context = _pass_context
|
||||
_pass_context = PassContext(runtime_shape)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_pass_context = prev_context
|
||||
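A quick usage sketch of the context manager above: passes read the shape being specialized through get_pass_context(), and the previous context is restored on exit.

with pass_context(runtime_shape=256):
    assert get_pass_context().runtime_shape == 256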
|
||||
|
||||
class InductorPass(CustomGraphPass):
|
||||
"""
|
||||
A custom graph pass that uses a hash of its source as the UUID.
|
||||
This is defined as a convenience and should work in most cases.
|
||||
"""
|
||||
|
||||
def uuid(self) -> Any:
|
||||
"""
|
||||
Provide a unique identifier for the pass, used in Inductor code cache.
|
||||
This should depend on the pass implementation, so that changes to the
|
||||
pass result in recompilation.
|
||||
By default, the object source is hashed.
|
||||
"""
|
||||
return InductorPass.hash_source(self)
|
||||
|
||||
@staticmethod
|
||||
def hash_source(*srcs: Union[str, Any]):
|
||||
"""
|
||||
Utility method to hash the sources of functions or objects.
|
||||
:param srcs: strings or objects to add to the hash.
|
||||
Objects and functions have their source inspected.
|
||||
:return:
|
||||
"""
|
||||
hasher = hashlib.sha256()
|
||||
for src in srcs:
|
||||
if isinstance(src, str):
|
||||
src_str = src
|
||||
elif isinstance(src, types.FunctionType):
|
||||
src_str = inspect.getsource(src)
|
||||
else:
|
||||
src_str = inspect.getsource(src.__class__)
|
||||
hasher.update(src_str.encode("utf-8"))
|
||||
return hasher.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def hash_dict(dict_: dict[Any, Any]):
|
||||
"""
|
||||
Utility method to hash a dictionary, can alternatively be used for uuid.
|
||||
:return: A sha256 hash of the json rep of the dictionary.
|
||||
"""
|
||||
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
|
||||
return hashlib.sha256(encoded).hexdigest()
|
||||
|
||||
def is_applicable_for_shape(self, shape: Optional[int]):
|
||||
return True
|
||||
|
||||
|
||||
class CallableInductorPass(InductorPass):
|
||||
"""
|
||||
This class is a wrapper for a callable that automatically provides an
|
||||
implementation of the UUID.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, callable: Callable[[fx.Graph], None], uuid: Optional[Any] = None
|
||||
):
|
||||
self.callable = callable
|
||||
self._uuid = self.hash_source(callable) if uuid is None else uuid
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.callable(graph)
|
||||
|
||||
def uuid(self) -> Any:
|
||||
return self._uuid
|
||||
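A minimal sketch of wrapping an ad-hoc callable so Inductor can cache it under a stable UUID (the pass body below is an illustrative no-op, not part of the change):

def _mark_nodes(graph: fx.Graph) -> None:
    # Illustrative pass body: tag every node without changing the graph.
    for node in graph.nodes:
        node.meta.setdefault("visited", True)

_pass = CallableInductorPass(_mark_nodes)
# The UUID is the sha256 of the callable's source, so editing _mark_nodes
# changes the UUID and invalidates Inductor's cached artifacts for it.
assert _pass.uuid() == InductorPass.hash_source(_mark_nodes)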
|
||||
|
||||
class SGLangInductorPass(InductorPass):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
self.pass_name = self.__class__.__name__
|
||||
|
||||
def dump_graph(self, graph: torch.fx.Graph, stage: str):
|
||||
lazy_format_graph_code(stage, graph.owning_module)
|
||||
|
||||
def begin(self):
|
||||
self._start_time = time.perf_counter_ns()
|
||||
|
||||
def end_and_log(self):
|
||||
self._end_time = time.perf_counter_ns()
|
||||
duration_ms = float(self._end_time - self._start_time) / 1.0e6
|
||||
logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
|
||||
|
||||
|
||||
class PrinterInductorPass(SGLangInductorPass):
|
||||
|
||||
def __init__(self, name: str):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.dump_graph(graph, self.name)
|
||||
python/sglang/srt/model_executor/compilation/pass_manager.py (new file, 68 lines)
@@ -0,0 +1,68 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/pass_manager.py
|
||||
|
||||
import logging
|
||||
|
||||
from torch import fx as fx
|
||||
|
||||
from sglang.srt.model_executor.compilation.fix_functionalization import (
|
||||
FixFunctionalizationPass,
|
||||
)
|
||||
from sglang.srt.model_executor.compilation.inductor_pass import (
|
||||
CustomGraphPass,
|
||||
InductorPass,
|
||||
SGLangInductorPass,
|
||||
get_pass_context,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PostGradPassManager(CustomGraphPass):
|
||||
"""
|
||||
The pass manager for post-grad passes.
|
||||
It handles configuration, adding custom passes, and running passes.
|
||||
It supports uuid for the Inductor code cache. That includes torch<2.6
|
||||
support using pickling (in .inductor_pass.CustomGraphPass).
|
||||
|
||||
The order of the post-grad post-passes is:
|
||||
1. passes (constructor parameter)
|
||||
2. default passes (NoopEliminationPass, FusionPass)
|
||||
3. config["post_grad_custom_post_pass"] (if it exists)
|
||||
4. fix_functionalization
|
||||
This way, all passes operate on a functionalized graph.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.passes: list[SGLangInductorPass] = []
|
||||
|
||||
def __call__(self, graph: fx.Graph):
|
||||
shape = get_pass_context().runtime_shape
|
||||
for pass_ in self.passes:
|
||||
if pass_.is_applicable_for_shape(shape):
|
||||
pass_(graph)
|
||||
|
||||
# always run fix_functionalization last
|
||||
self.fix_functionalization(graph)
|
||||
|
||||
def configure(
|
||||
self,
|
||||
):
|
||||
self.pass_config = dict()
|
||||
self.fix_functionalization = FixFunctionalizationPass()
|
||||
|
||||
def add(self, pass_: InductorPass):
|
||||
assert isinstance(pass_, InductorPass)
|
||||
self.passes.append(pass_)
|
||||
|
||||
def uuid(self):
|
||||
"""
|
||||
The PostGradPassManager is set as a custom pass in the Inductor and
|
||||
affects compilation caching. Its uuid depends on the UUIDs of all
|
||||
dependent passes and the pass config. See InductorPass for more info.
|
||||
"""
|
||||
pass_manager_uuid = "fshdakhsa"
|
||||
state = {"pass_config": pass_manager_uuid, "passes": []}
|
||||
for pass_ in self.passes:
|
||||
state["passes"].append(pass_.uuid())
|
||||
state["passes"].append(self.fix_functionalization.uuid())
|
||||
return InductorPass.hash_dict(state)
|
||||
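A hedged sketch of how the manager is wired up (assumes CallableInductorPass is imported from the inductor_pass module above; the added pass is a placeholder):

pm = PostGradPassManager()
pm.configure()  # installs the trailing FixFunctionalizationPass
pm.add(CallableInductorPass(lambda graph: None))
# Inductor calls pm(graph) after its own post-grad passes; the call must run
# inside pass_context(runtime_shape) so is_applicable_for_shape() sees the shape.
cache_key = pm.uuid()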
@@ -0,0 +1,40 @@
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
@dataclass
|
||||
class ForwardContext:
|
||||
def __init__(self):
|
||||
self.forward_batch = None
|
||||
self.attention_layers = None
|
||||
|
||||
def set_forward_batch(self, forward_batch: ForwardBatch):
|
||||
self.forward_batch = forward_batch
|
||||
|
||||
def set_attention_layers(self, layers: List[Any]):
|
||||
self.attention_layers = layers
|
||||
|
||||
|
||||
_forward_context: Optional[ForwardContext] = None
|
||||
|
||||
|
||||
def get_forward_context() -> Optional[ForwardContext]:
|
||||
if _forward_context is None:
|
||||
return None
|
||||
return _forward_context
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_forward_context(forward_batch: ForwardBatch, attention_layers: List[Any]):
|
||||
global _forward_context
|
||||
prev_forward_context = _forward_context
|
||||
_forward_context = ForwardContext()
|
||||
_forward_context.set_forward_batch(forward_batch)
|
||||
_forward_context.set_attention_layers(attention_layers)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_forward_context = prev_forward_context
|
||||
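A minimal usage sketch of the helpers above (the stand-in batch object is only for illustration; the runner passes a real ForwardBatch and the collected attention modules):

class _FakeBatch:
    pass

_fb = _FakeBatch()
with set_forward_context(_fb, attention_layers=[]):
    assert get_forward_context().forward_batch is _fb
assert get_forward_context() is None  # previous (empty) context restored on exit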
@@ -0,0 +1,28 @@
|
||||
// Adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/ops.h
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <vector>
|
||||
|
||||
static at::Tensor weak_ref_tensor(at::Tensor &tensor) {
|
||||
TORCH_CHECK(tensor.is_cuda(), "weak_ref_tensor expects a CUDA tensor");
|
||||
|
||||
void *data_ptr = tensor.data_ptr();
|
||||
std::vector<int64_t> sizes = tensor.sizes().vec();
|
||||
std::vector<int64_t> strides = tensor.strides().vec();
|
||||
|
||||
auto options = tensor.options();
|
||||
|
||||
auto new_tensor = torch::from_blob(data_ptr, sizes, strides, options);
|
||||
|
||||
return new_tensor;
|
||||
}
|
||||
|
||||
TORCH_LIBRARY(jit_weak_ref_tensor, ops) {
|
||||
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(jit_weak_ref_tensor, CUDA, ops) {
|
||||
ops.impl("weak_ref_tensor", weak_ref_tensor);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
|
||||
@@ -0,0 +1,16 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
_abs_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
load(
|
||||
name="weak_ref_tensor_ext",
|
||||
sources=[f"{_abs_path}/weak_ref_tensor.cpp"],
|
||||
extra_cflags=["-O3"],
|
||||
)
|
||||
|
||||
x = torch.arange(12, device="cuda").reshape(3, 4)
|
||||
y = torch.ops.jit_weak_ref_tensor.weak_ref_tensor(x)
|
||||
print("alias:", x.data_ptr() == y.data_ptr())
|
||||
@@ -108,8 +108,15 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
)
|
||||
from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
|
||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
PPProxyTensors,
|
||||
)
|
||||
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
|
||||
from sglang.srt.model_executor.piecewise_cuda_graph_runner import (
|
||||
PiecewiseCudaGraphRunner,
|
||||
)
|
||||
from sglang.srt.model_loader import get_model
|
||||
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
|
||||
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
|
||||
@@ -307,6 +314,26 @@ class ModelRunner:
|
||||
self._model_update_group = {}
|
||||
self._weights_send_group = {}
|
||||
|
||||
if (
|
||||
self.server_args.enable_piecewise_cuda_graph
|
||||
and self.can_run_piecewise_cuda_graph()
|
||||
):
|
||||
self.attention_layers = []
|
||||
for layer in self.model.model.layers:
|
||||
if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "attn"):
|
||||
self.attention_layers.append(layer.self_attn.attn)
|
||||
if len(self.attention_layers) < self.model_config.num_hidden_layers:
|
||||
# TODO(yuwei): support Non-Standard GQA
|
||||
log_info_on_rank0(
|
||||
logger,
|
||||
"Disable piecewise CUDA graph because some layers do not apply Standard GQA",
|
||||
)
|
||||
self.piecewise_cuda_graph_runner = None
|
||||
else:
|
||||
self.piecewise_cuda_graph_runner = PiecewiseCudaGraphRunner(self)
|
||||
else:
|
||||
self.piecewise_cuda_graph_runner = None
|
||||
|
||||
def initialize(self, min_per_gpu_memory: float):
|
||||
server_args = self.server_args
|
||||
|
||||
@@ -692,6 +719,7 @@ class ModelRunner:
|
||||
pipeline_model_parallel_size=self.pp_size,
|
||||
expert_model_parallel_size=self.moe_ep_size,
|
||||
duplicate_tp_group=self.server_args.enable_pdmux,
|
||||
torch_compile=self.server_args.enable_piecewise_cuda_graph,
|
||||
)
|
||||
initialize_dp_attention(
|
||||
server_args=self.server_args,
|
||||
@@ -1411,6 +1439,27 @@ class ModelRunner:
|
||||
f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
|
||||
)
|
||||
|
||||
def can_run_piecewise_cuda_graph(self):
|
||||
if self.server_args.disable_cuda_graph:
|
||||
log_info_on_rank0(
|
||||
logger, "Disable piecewise CUDA graph because disable_cuda_graph is set"
|
||||
)
|
||||
return False
|
||||
if self.server_args.enable_torch_compile:
|
||||
log_info_on_rank0(
|
||||
logger,
|
||||
"Disable piecewise CUDA graph because piecewise_cuda_graph has conflict with torch compile",
|
||||
)
|
||||
return False
|
||||
if self.pp_size > 1:
|
||||
# TODO(yuwei): support PP
|
||||
log_info_on_rank0(
|
||||
logger,
|
||||
"Disable piecewise CUDA graph because piecewise_cuda_graph does not support PP",
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
def init_memory_pool(
|
||||
self,
|
||||
total_gpu_memory: int,
|
||||
@@ -1932,6 +1981,11 @@ class ModelRunner:
|
||||
kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
|
||||
if not self.is_generation:
|
||||
kwargs["get_embedding"] = True
|
||||
|
||||
if self.piecewise_cuda_graph_runner is not None:
|
||||
if self.piecewise_cuda_graph_runner.can_run(forward_batch):
|
||||
return self.piecewise_cuda_graph_runner.replay(forward_batch, **kwargs)
|
||||
|
||||
return self.model.forward(
|
||||
forward_batch.input_ids,
|
||||
forward_batch.positions,
|
||||
|
||||
python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py (new file, 532 lines)
@@ -0,0 +1,532 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Run the model with cuda graph and torch.compile."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import bisect
|
||||
import gc
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
|
||||
set_graph_pool_id,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import graph_capture
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
DpPaddingMode,
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
set_dp_buffer_len,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.torchao_utils import save_gemlite_cache
|
||||
from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig
|
||||
from sglang.srt.model_executor.compilation.compile import (
|
||||
install_torch_compiled,
|
||||
set_compiled,
|
||||
)
|
||||
from sglang.srt.model_executor.compilation.piecewise_context_manager import (
|
||||
set_forward_context,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
CaptureHiddenMode,
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
PPProxyTensors,
|
||||
)
|
||||
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
|
||||
from sglang.srt.utils import get_available_gpu_memory, log_info_on_rank0
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
# Detect whether the current forward pass is in capture mode
|
||||
is_capture_mode = False
|
||||
|
||||
|
||||
def get_is_capture_mode():
|
||||
return is_capture_mode
|
||||
|
||||
|
||||
@contextmanager
|
||||
def model_capture_mode():
|
||||
global is_capture_mode
|
||||
is_capture_mode = True
|
||||
|
||||
yield
|
||||
|
||||
is_capture_mode = False
|
||||
|
||||
|
||||
@contextmanager
|
||||
def freeze_gc(enable_cudagraph_gc: bool):
|
||||
"""
|
||||
Optimize garbage collection during CUDA graph capture.
|
||||
Clean up, then freeze all remaining objects from being included
|
||||
in future collections if GC is disabled during capture.
|
||||
"""
|
||||
gc.collect()
|
||||
should_freeze = not enable_cudagraph_gc
|
||||
if should_freeze:
|
||||
gc.freeze()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if should_freeze:
|
||||
gc.unfreeze()
|
||||
|
||||
|
||||
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
|
||||
for sub in model._modules.values():
|
||||
if isinstance(sub, CustomOp):
|
||||
if reverse:
|
||||
sub.leave_torch_compile()
|
||||
else:
|
||||
sub.enter_torch_compile(num_tokens=num_tokens)
|
||||
if isinstance(sub, torch.nn.Module):
|
||||
_to_torch(sub, reverse, num_tokens)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def patch_model(model: torch.nn.Module):
|
||||
try:
|
||||
_to_torch(model, reverse=False, num_tokens=16)
|
||||
yield model
|
||||
finally:
|
||||
_to_torch(model, reverse=True, num_tokens=16)
|
||||
|
||||
|
||||
# Reuse this memory pool across all cuda graph runners.
|
||||
global_graph_memory_pool = None
|
||||
|
||||
|
||||
def get_global_graph_memory_pool():
|
||||
return global_graph_memory_pool
|
||||
|
||||
|
||||
def set_global_graph_memory_pool(val):
|
||||
global global_graph_memory_pool
|
||||
global_graph_memory_pool = val
|
||||
|
||||
|
||||
class PiecewiseCudaGraphRunner:
|
||||
"""A PiecewiseCudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
|
||||
|
||||
def __init__(self, model_runner: ModelRunner):
|
||||
# Parse args
|
||||
self.model_runner = model_runner
|
||||
self.device = model_runner.device
|
||||
self.device_module = torch.get_device_module(self.device)
|
||||
self.graphs = {}
|
||||
self.output_buffers = {}
|
||||
self.tp_size = model_runner.server_args.tp_size
|
||||
self.dp_size = model_runner.server_args.dp_size
|
||||
self.pp_size = model_runner.server_args.pp_size
|
||||
|
||||
self.attn_tp_size = get_attention_tp_size()
|
||||
self.attn_tp_rank = get_attention_tp_rank()
|
||||
|
||||
assert (
|
||||
self.model_runner.server_args.piecewise_cuda_graph_tokens is not None
|
||||
), "piecewise_cuda_graph_tokens is not set"
|
||||
self.compile_config = CompilationConfig(
|
||||
self.model_runner.server_args.piecewise_cuda_graph_tokens
|
||||
)
|
||||
|
||||
# Batch sizes to capture
|
||||
self.capture_num_tokens = self.compile_config.get_capture_sizes()
|
||||
log_info_on_rank0(
|
||||
logger, f"Capture cuda graph num tokens {self.capture_num_tokens}"
|
||||
)
|
||||
self.capture_forward_mode = ForwardMode.EXTEND
|
||||
self.capture_hidden_mode = CaptureHiddenMode.NULL
|
||||
|
||||
# If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
|
||||
if model_runner.server_args.enable_return_hidden_states:
|
||||
self.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||
|
||||
# Attention backend
|
||||
self.max_num_tokens = max(self.capture_num_tokens)
|
||||
|
||||
# Graph inputs
|
||||
with torch.device(self.device):
|
||||
self.input_ids = torch.zeros((self.max_num_tokens,), dtype=torch.int64)
|
||||
self.out_cache_loc = torch.zeros(
|
||||
(self.max_num_tokens,), dtype=self._cache_loc_dtype()
|
||||
)
|
||||
self.positions = torch.zeros((self.max_num_tokens,), dtype=torch.int64)
|
||||
self.tbo_plugin = TboCudaGraphRunnerPlugin()
|
||||
|
||||
self.attention_layers = self.model_runner.attention_layers
|
||||
|
||||
if get_global_graph_memory_pool() is None:
|
||||
set_global_graph_memory_pool(self.device_module.graph_pool_handle())
|
||||
# Set graph pool id globally to be able to use symmetric memory
|
||||
set_graph_pool_id(get_global_graph_memory_pool())
|
||||
|
||||
with patch_model(self.model_runner.model.model) as patched_model:
|
||||
install_torch_compiled(
|
||||
patched_model,
|
||||
fullgraph=True,
|
||||
dynamic_arg_dims=None,
|
||||
compile_config=self.compile_config,
|
||||
graph_pool=get_global_graph_memory_pool(),
|
||||
)
|
||||
|
||||
with set_compiled(True):
|
||||
self.warmup_and_capture()
|
||||
|
||||
# Capture
|
||||
try:
|
||||
with model_capture_mode():
|
||||
self.capture()
|
||||
except RuntimeError as e:
|
||||
raise Exception(
|
||||
f"Capture cuda graph failed: {e}\n{PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
||||
)
|
||||
|
||||
self.raw_num_tokens = 0
|
||||
|
||||
def warmup_and_capture(self):
|
||||
num_tokens = 2
|
||||
with torch.device(self.device):
|
||||
forward_batch = ForwardBatch(
|
||||
forward_mode=ForwardMode.EXTEND,
|
||||
batch_size=1,
|
||||
input_ids=torch.randint(0, 100, (num_tokens,), device=self.device),
|
||||
req_pool_indices=torch.arange(1, device=self.device),
|
||||
seq_lens=torch.tensor([num_tokens], device=self.device),
|
||||
next_token_logits_buffer=None,
|
||||
orig_seq_lens=torch.tensor([num_tokens], device=self.device),
|
||||
seq_lens_cpu=torch.tensor([num_tokens]),
|
||||
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||
attn_backend=self.model_runner.attn_backend,
|
||||
out_cache_loc=torch.randint(0, 100, (num_tokens,), device=self.device),
|
||||
seq_lens_sum=num_tokens,
|
||||
encoder_lens=None,
|
||||
return_logprob=False,
|
||||
extend_seq_lens=torch.tensor([num_tokens], device=self.device),
|
||||
extend_prefix_lens=torch.tensor([num_tokens], device=self.device),
|
||||
extend_start_loc=torch.tensor([0], device=self.device),
|
||||
extend_prefix_lens_cpu=torch.tensor([num_tokens]),
|
||||
extend_seq_lens_cpu=torch.tensor([num_tokens]),
|
||||
extend_logprob_start_lens_cpu=torch.tensor([num_tokens]),
|
||||
positions=torch.arange(num_tokens, device=self.device),
|
||||
global_num_tokens_gpu=None,
|
||||
global_num_tokens_for_logprob_gpu=None,
|
||||
dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
|
||||
global_dp_buffer_len=None,
|
||||
mrope_positions=None,
|
||||
spec_algorithm=None,
|
||||
spec_info=None,
|
||||
capture_hidden_mode=CaptureHiddenMode.NULL,
|
||||
num_token_non_padded=None,
|
||||
global_forward_mode=ForwardMode.EXTEND,
|
||||
lora_ids=None,
|
||||
)
|
||||
|
||||
with set_forward_context(forward_batch, self.attention_layers):
|
||||
_ = self.model_runner.model.forward(
|
||||
forward_batch.input_ids,
|
||||
forward_batch.positions,
|
||||
forward_batch,
|
||||
)
|
||||
|
||||
def _cache_loc_dtype(self):
|
||||
return torch.int64
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
num_tokens = len(forward_batch.input_ids)
|
||||
# TODO(yuwei): support return logprob
|
||||
if forward_batch.return_logprob:
|
||||
return False
|
||||
if num_tokens <= self.max_num_tokens:
|
||||
return True
|
||||
return False
|
||||
|
||||
def capture(self) -> None:
|
||||
# Trigger CUDA graph capture for specific shapes.
|
||||
# Capture the large shapes first so that the smaller shapes
|
||||
# can reuse the memory pool allocated for the large shapes.
|
||||
with freeze_gc(
|
||||
self.model_runner.server_args.enable_cudagraph_gc
|
||||
), graph_capture() as graph_capture_context:
|
||||
self.stream = graph_capture_context.stream
|
||||
avail_mem = get_available_gpu_memory(
|
||||
self.model_runner.device,
|
||||
self.model_runner.gpu_id,
|
||||
empty_cache=False,
|
||||
)
|
||||
# Reverse the order to enable better memory sharing across cuda graphs.
|
||||
capture_range = (
|
||||
tqdm.tqdm(list(reversed(self.capture_num_tokens)))
|
||||
if get_tensor_model_parallel_rank() == 0
|
||||
else reversed(self.capture_num_tokens)
|
||||
)
|
||||
for i, num_tokens in enumerate(capture_range):
|
||||
if get_tensor_model_parallel_rank() == 0:
|
||||
avail_mem = get_available_gpu_memory(
|
||||
self.model_runner.device,
|
||||
self.model_runner.gpu_id,
|
||||
empty_cache=False,
|
||||
)
|
||||
capture_range.set_description(
|
||||
f"Capturing num tokens ({num_tokens=} {avail_mem=:.2f} GB)"
|
||||
)
|
||||
|
||||
with set_compiled(True):
|
||||
self.capture_one_batch_size(num_tokens)
|
||||
|
||||
# Save gemlite cache after each capture
|
||||
save_gemlite_cache()
|
||||
|
||||
def capture_one_batch_size(self, num_tokens: int):
|
||||
stream = self.stream
|
||||
bs = 1
|
||||
|
||||
# Graph inputs
|
||||
input_ids = self.input_ids[:num_tokens]
|
||||
out_cache_loc = self.out_cache_loc[:num_tokens]
|
||||
positions = self.positions[:num_tokens]
|
||||
|
||||
# pipeline parallelism
|
||||
if self.pp_size > 1:
|
||||
pp_proxy_tensors = PPProxyTensors(
|
||||
{k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
|
||||
)
|
||||
|
||||
global_dp_buffer_len = None
|
||||
|
||||
if self.model_runner.server_args.enable_lora:
|
||||
# It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
|
||||
# `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
|
||||
lora_ids = [None] * bs
|
||||
else:
|
||||
lora_ids = None
|
||||
|
||||
with torch.device(self.device):
|
||||
forward_batch = ForwardBatch(
|
||||
forward_mode=ForwardMode.EXTEND,
|
||||
batch_size=bs,
|
||||
input_ids=input_ids,
|
||||
req_pool_indices=torch.arange(bs, device=self.device),
|
||||
seq_lens=torch.tensor([num_tokens], device=self.device),
|
||||
next_token_logits_buffer=None,
|
||||
orig_seq_lens=torch.tensor([num_tokens], device=self.device),
|
||||
seq_lens_cpu=torch.tensor([num_tokens]),
|
||||
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||
attn_backend=self.model_runner.attn_backend,
|
||||
out_cache_loc=out_cache_loc,
|
||||
seq_lens_sum=num_tokens,
|
||||
encoder_lens=None,
|
||||
return_logprob=False,
|
||||
extend_seq_lens=torch.tensor([num_tokens], device=self.device),
|
||||
extend_prefix_lens=torch.tensor([num_tokens], device=self.device),
|
||||
extend_start_loc=torch.tensor([0], device=self.device),
|
||||
extend_prefix_lens_cpu=torch.tensor([num_tokens]),
|
||||
extend_seq_lens_cpu=torch.tensor([num_tokens]),
|
||||
extend_logprob_start_lens_cpu=torch.tensor([num_tokens]),
|
||||
positions=positions,
|
||||
global_num_tokens_gpu=None,
|
||||
global_num_tokens_for_logprob_gpu=None,
|
||||
dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
|
||||
global_dp_buffer_len=None,
|
||||
mrope_positions=None,
|
||||
spec_algorithm=None,
|
||||
spec_info=None,
|
||||
capture_hidden_mode=CaptureHiddenMode.NULL,
|
||||
num_token_non_padded=None,
|
||||
global_forward_mode=ForwardMode.EXTEND,
|
||||
lora_ids=None,
|
||||
)
|
||||
self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
|
||||
|
||||
if lora_ids is not None:
|
||||
self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
|
||||
|
||||
# Attention backend
|
||||
self.model_runner.attn_backend.init_forward_metadata(forward_batch)
|
||||
|
||||
# Run and capture
|
||||
def run_once():
|
||||
# Clean intermediate result cache for DP attention
|
||||
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
||||
set_dp_buffer_len(global_dp_buffer_len, num_tokens)
|
||||
|
||||
kwargs = {}
|
||||
with set_forward_context(forward_batch, self.attention_layers):
|
||||
self.model_runner.model.forward(
|
||||
forward_batch.input_ids,
|
||||
forward_batch.positions,
|
||||
forward_batch,
|
||||
**kwargs,
|
||||
)
|
||||
return
|
||||
|
||||
for _ in range(2):
|
||||
self.device_module.synchronize()
|
||||
self.model_runner.tp_group.barrier()
|
||||
run_once()
|
||||
|
||||
return
|
||||
|
||||
def replay_prepare(
|
||||
self,
|
||||
forward_batch: ForwardBatch,
|
||||
**kwargs,
|
||||
):
|
||||
num_tokens = len(forward_batch.input_ids)
|
||||
index = bisect.bisect_left(self.capture_num_tokens, num_tokens)
|
||||
static_num_tokens = self.capture_num_tokens[index]
|
||||
self.raw_num_tokens = num_tokens
|
||||
if static_num_tokens != num_tokens:
|
||||
self.out_cache_loc.zero_()
|
||||
bs = forward_batch.batch_size
|
||||
|
||||
self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
|
||||
self.positions[:num_tokens].copy_(forward_batch.positions)
|
||||
self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
|
||||
|
||||
input_ids = self.input_ids[:static_num_tokens]
|
||||
positions = self.positions[:static_num_tokens]
|
||||
out_cache_loc = self.out_cache_loc[:static_num_tokens]
|
||||
|
||||
next_token_logits_buffer = None
|
||||
mrope_positions = None
|
||||
|
||||
static_forward_batch = ForwardBatch(
|
||||
forward_mode=forward_batch.forward_mode,
|
||||
batch_size=bs,
|
||||
input_ids=input_ids,
|
||||
req_pool_indices=forward_batch.req_pool_indices,
|
||||
seq_lens=forward_batch.seq_lens,
|
||||
next_token_logits_buffer=next_token_logits_buffer,
|
||||
orig_seq_lens=forward_batch.orig_seq_lens,
|
||||
seq_lens_cpu=forward_batch.seq_lens_cpu,
|
||||
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||
attn_backend=self.model_runner.attn_backend,
|
||||
out_cache_loc=out_cache_loc,
|
||||
seq_lens_sum=forward_batch.seq_lens_sum,
|
||||
encoder_lens=forward_batch.encoder_lens,
|
||||
return_logprob=forward_batch.return_logprob,
|
||||
extend_seq_lens=forward_batch.extend_seq_lens,
|
||||
extend_prefix_lens=forward_batch.extend_prefix_lens,
|
||||
extend_start_loc=forward_batch.extend_start_loc,
|
||||
extend_prefix_lens_cpu=forward_batch.extend_prefix_lens_cpu,
|
||||
extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
|
||||
extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
|
||||
extend_num_tokens=forward_batch.extend_num_tokens,
|
||||
extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
|
||||
positions=positions,
|
||||
global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
|
||||
global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
|
||||
dp_padding_mode=forward_batch.dp_padding_mode,
|
||||
global_dp_buffer_len=forward_batch.global_dp_buffer_len,
|
||||
mrope_positions=mrope_positions,
|
||||
spec_algorithm=forward_batch.spec_algorithm,
|
||||
spec_info=forward_batch.spec_info,
|
||||
capture_hidden_mode=forward_batch.capture_hidden_mode,
|
||||
num_token_non_padded=forward_batch.num_token_non_padded,
|
||||
global_forward_mode=forward_batch.global_forward_mode,
|
||||
lora_ids=forward_batch.lora_ids,
|
||||
sampling_info=forward_batch.sampling_info,
|
||||
mm_inputs=forward_batch.mm_inputs,
|
||||
temp_scaled_logprobs=forward_batch.temp_scaled_logprobs,
|
||||
temperature=forward_batch.temperature,
|
||||
top_p_normalized_logprobs=forward_batch.top_p_normalized_logprobs,
|
||||
top_p=forward_batch.top_p,
|
||||
)
|
||||
|
||||
return static_forward_batch
|
||||
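A standalone sketch of the padding decision in replay_prepare above: the batch is rounded up to the nearest captured token count, and only the first num_tokens slots of the static buffers carry real data (the capture sizes below are illustrative).

import bisect

capture_num_tokens = [4, 8, 16, 32]
num_tokens = 5
index = bisect.bisect_left(capture_num_tokens, num_tokens)
static_num_tokens = capture_num_tokens[index]
assert static_num_tokens == 8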
|
||||
def replay(
|
||||
self,
|
||||
forward_batch: ForwardBatch,
|
||||
**kwargs,
|
||||
) -> Union[LogitsProcessorOutput, PPProxyTensors]:
|
||||
static_forward_batch = self.replay_prepare(forward_batch, **kwargs)
|
||||
# Replay
|
||||
with set_forward_context(static_forward_batch, self.attention_layers):
|
||||
with set_compiled(True):
|
||||
output = self.model_runner.model.forward(
|
||||
static_forward_batch.input_ids,
|
||||
static_forward_batch.positions,
|
||||
static_forward_batch,
|
||||
**kwargs,
|
||||
)
|
||||
if isinstance(output, LogitsProcessorOutput):
|
||||
return LogitsProcessorOutput(
|
||||
next_token_logits=output.next_token_logits[: self.raw_num_tokens],
|
||||
hidden_states=(
|
||||
output.hidden_states[: self.raw_num_tokens]
|
||||
if output.hidden_states is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
else:
|
||||
assert isinstance(output, PPProxyTensors)
|
||||
# TODO(Yuwei): support PP
|
||||
raise NotImplementedError(
|
||||
"PPProxyTensors is not supported in PiecewiseCudaGraphRunner yet."
|
||||
)
|
||||
|
||||
def get_spec_info(self, num_tokens: int):
|
||||
spec_info = None
|
||||
if (
|
||||
self.model_runner.spec_algorithm.is_eagle()
|
||||
or self.model_runner.spec_algorithm.is_standalone()
|
||||
):
|
||||
from sglang.srt.speculative.eagle_utils import EagleVerifyInput
|
||||
|
||||
if self.model_runner.is_draft_worker:
|
||||
raise RuntimeError("This should not happen.")
|
||||
else:
|
||||
spec_info = EagleVerifyInput(
|
||||
draft_token=None,
|
||||
custom_mask=self.custom_mask,
|
||||
positions=None,
|
||||
retrive_index=None,
|
||||
retrive_next_token=None,
|
||||
retrive_next_sibling=None,
|
||||
retrive_cum_len=None,
|
||||
spec_steps=self.model_runner.server_args.speculative_num_steps,
|
||||
topk=self.model_runner.server_args.speculative_eagle_topk,
|
||||
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
|
||||
capture_hidden_mode=CaptureHiddenMode.FULL,
|
||||
seq_lens_sum=None,
|
||||
seq_lens_cpu=None,
|
||||
)
|
||||
|
||||
return spec_info
|
||||
|
||||
|
||||
PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG = (
|
||||
"Possible solutions:\n"
|
||||
"1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
|
||||
"2. set --piecewise-cuda-graph-max-tokens to a smaller value (e.g., 512)\n"
|
||||
"3. disable Piecewise CUDA graph by unset --enable-piecewise-cuda-graph\n"
|
||||
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
|
||||
)
|
||||
@@ -417,7 +417,10 @@ class ServerArgs:
|
||||
enable_single_batch_overlap: bool = False
|
||||
tbo_token_distribution_threshold: float = 0.48
|
||||
enable_torch_compile: bool = False
|
||||
enable_piecewise_cuda_graph: bool = False
|
||||
torch_compile_max_bs: int = 32
|
||||
piecewise_cuda_graph_max_tokens: int = 4096
|
||||
piecewise_cuda_graph_tokens: Optional[List[int]] = None
|
||||
torchao_config: str = ""
|
||||
enable_nan_detection: bool = False
|
||||
enable_p2p_check: bool = False
|
||||
@@ -675,6 +678,11 @@ class ServerArgs:
|
||||
else:
|
||||
self.cuda_graph_max_bs = max(self.cuda_graph_bs)
|
||||
|
||||
if self.piecewise_cuda_graph_tokens is None:
|
||||
self.piecewise_cuda_graph_tokens = (
|
||||
self._generate_piecewise_cuda_graph_tokens()
|
||||
)
|
||||
|
||||
if self.mem_fraction_static is None:
|
||||
# Constant meta data (e.g., from attention backend)
|
||||
reserved_mem = 512
|
||||
@@ -753,6 +761,25 @@ class ServerArgs:
|
||||
|
||||
return capture_bs
|
||||
|
||||
def _generate_piecewise_cuda_graph_tokens(self):
|
||||
"""
|
||||
Generate the list of token counts for piecewise CUDA graph capture
|
||||
based on piecewise_cuda_graph_max_tokens.
|
||||
"""
|
||||
capture_sizes = (
|
||||
list(range(4, 33, 4))
|
||||
+ list(range(48, 257, 16))
|
||||
+ list(range(288, 513, 32))
|
||||
+ list(range(640, 4096 + 1, 128))
|
||||
+ list(range(4352, self.piecewise_cuda_graph_max_tokens + 1, 256))
|
||||
)
|
||||
|
||||
capture_sizes = [
|
||||
s for s in capture_sizes if s <= self.piecewise_cuda_graph_max_tokens
|
||||
]
|
||||
|
||||
return capture_sizes
|
||||
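For intuition, the schedule above uses fine granularity for small token counts and progressively coarser steps toward the cap; a standalone sketch with a hypothetical 1024-token limit:

max_tokens = 1024  # illustrative value for --piecewise-cuda-graph-max-tokens
sizes = (
    list(range(4, 33, 4))        # 4, 8, ..., 32
    + list(range(48, 257, 16))   # 48, 64, ..., 256
    + list(range(288, 513, 32))  # 288, 320, ..., 512
    + list(range(640, 4096 + 1, 128))
    + list(range(4352, max_tokens + 1, 256))
)
sizes = [s for s in sizes if s <= max_tokens]
# -> ends with [..., 512, 640, 768, 896, 1024]; larger caps extend the tail in 128/256-token steps.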
|
||||
def _handle_hpu_backends(self):
|
||||
if self.device == "hpu":
|
||||
self.attention_backend = "torch_native"
|
||||
@@ -2649,12 +2676,29 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Optimize the model with torch.compile. Experimental feature.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-piecewise-cuda-graph",
|
||||
action="store_true",
|
||||
help="Optimize the model with piecewise cuda graph for extend/prefill only. Experimental feature.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--piecewise-cuda-graph-tokens",
|
||||
type=json_list_type,
|
||||
default=ServerArgs.piecewise_cuda_graph_tokens,
|
||||
help="Set the list of tokens when using piecewise cuda graph.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--torch-compile-max-bs",
|
||||
type=int,
|
||||
default=ServerArgs.torch_compile_max_bs,
|
||||
help="Set the maximum batch size when using torch compile.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--piecewise-cuda-graph-max-tokens",
|
||||
type=int,
|
||||
default=ServerArgs.piecewise_cuda_graph_max_tokens,
|
||||
help="Set the maximum tokens when using piecewise cuda graph.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--torchao-config",
|
||||
type=str,
|
||||
|
||||