Piecewise CUDA Graph Support & Torch Compile Backend (#10062)
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
This commit is contained in:
@@ -185,7 +185,7 @@ class CustomAllreduce:
|
|||||||
# is enough for 131072 such tuples. The largest model I've seen only
|
# is enough for 131072 such tuples. The largest model I've seen only
|
||||||
# needs less than 10000 of registered tuples.
|
# needs less than 10000 of registered tuples.
|
||||||
self.rank_data = torch.empty(
|
self.rank_data = torch.empty(
|
||||||
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
|
max_size, dtype=torch.uint8, device=self.device
|
||||||
)
|
)
|
||||||
self._ptr = ops.init_custom_ar(
|
self._ptr = ops.init_custom_ar(
|
||||||
self.meta_ptrs, self.rank_data, rank, self.full_nvlink
|
self.meta_ptrs, self.rank_data, rank, self.full_nvlink
|
||||||
@@ -202,7 +202,7 @@ class CustomAllreduce:
|
|||||||
)
|
)
|
||||||
handles, offsets = self._gather_ipc_meta(shard_data)
|
handles, offsets = self._gather_ipc_meta(shard_data)
|
||||||
self.rank_data = torch.empty(
|
self.rank_data = torch.empty(
|
||||||
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
|
max_size, dtype=torch.uint8, device=self.device
|
||||||
)
|
)
|
||||||
self._ptr = ops.init_custom_ar(
|
self._ptr = ops.init_custom_ar(
|
||||||
self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
|
self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
|
||||||
|
|||||||
@@ -239,6 +239,7 @@ class GroupCoordinator:
|
|||||||
use_npu_communicator: bool,
|
use_npu_communicator: bool,
|
||||||
use_message_queue_broadcaster: bool = False,
|
use_message_queue_broadcaster: bool = False,
|
||||||
group_name: Optional[str] = None,
|
group_name: Optional[str] = None,
|
||||||
|
torch_compile: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
# Set group info
|
# Set group info
|
||||||
group_name = group_name or "anonymous"
|
group_name = group_name or "anonymous"
|
||||||
@@ -326,10 +327,18 @@ class GroupCoordinator:
|
|||||||
self.qr_comm: Optional[QuickAllReduce] = None
|
self.qr_comm: Optional[QuickAllReduce] = None
|
||||||
if use_custom_allreduce and self.world_size > 1:
|
if use_custom_allreduce and self.world_size > 1:
|
||||||
# Initialize a custom fast all-reduce implementation.
|
# Initialize a custom fast all-reduce implementation.
|
||||||
|
if torch_compile is not None and torch_compile:
|
||||||
|
# For piecewise CUDA graph, the requirement for custom allreduce is larger to
|
||||||
|
# avoid illegal cuda memory access.
|
||||||
|
ca_max_size = 256 * 1024 * 1024
|
||||||
|
else:
|
||||||
|
ca_max_size = 8 * 1024 * 1024
|
||||||
try:
|
try:
|
||||||
|
# print(f"ca_max_size: {ca_max_size}")
|
||||||
self.ca_comm = CustomAllreduce(
|
self.ca_comm = CustomAllreduce(
|
||||||
group=self.cpu_group,
|
group=self.cpu_group,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
|
max_size=ca_max_size,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -1260,6 +1269,7 @@ def init_model_parallel_group(
|
|||||||
group_name: Optional[str] = None,
|
group_name: Optional[str] = None,
|
||||||
use_mscclpp_allreduce: Optional[bool] = None,
|
use_mscclpp_allreduce: Optional[bool] = None,
|
||||||
use_symm_mem_allreduce: Optional[bool] = None,
|
use_symm_mem_allreduce: Optional[bool] = None,
|
||||||
|
torch_compile: Optional[bool] = None,
|
||||||
) -> GroupCoordinator:
|
) -> GroupCoordinator:
|
||||||
if use_custom_allreduce is None:
|
if use_custom_allreduce is None:
|
||||||
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
|
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
|
||||||
@@ -1280,6 +1290,7 @@ def init_model_parallel_group(
|
|||||||
use_npu_communicator=True,
|
use_npu_communicator=True,
|
||||||
use_message_queue_broadcaster=use_message_queue_broadcaster,
|
use_message_queue_broadcaster=use_message_queue_broadcaster,
|
||||||
group_name=group_name,
|
group_name=group_name,
|
||||||
|
torch_compile=torch_compile,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1439,6 +1450,7 @@ def initialize_model_parallel(
|
|||||||
pipeline_model_parallel_size: int = 1,
|
pipeline_model_parallel_size: int = 1,
|
||||||
backend: Optional[str] = None,
|
backend: Optional[str] = None,
|
||||||
duplicate_tp_group: bool = False,
|
duplicate_tp_group: bool = False,
|
||||||
|
torch_compile: Optional[bool] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize model parallel groups.
|
Initialize model parallel groups.
|
||||||
@@ -1494,6 +1506,7 @@ def initialize_model_parallel(
|
|||||||
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
|
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
|
||||||
),
|
),
|
||||||
group_name="tp",
|
group_name="tp",
|
||||||
|
torch_compile=torch_compile,
|
||||||
)
|
)
|
||||||
|
|
||||||
if duplicate_tp_group:
|
if duplicate_tp_group:
|
||||||
@@ -1509,6 +1522,7 @@ def initialize_model_parallel(
|
|||||||
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
|
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
|
||||||
),
|
),
|
||||||
group_name="pdmux_prefill_tp",
|
group_name="pdmux_prefill_tp",
|
||||||
|
torch_compile=torch_compile,
|
||||||
)
|
)
|
||||||
_TP.pynccl_comm.disabled = False
|
_TP.pynccl_comm.disabled = False
|
||||||
_PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
|
_PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
|
||||||
@@ -1518,7 +1532,6 @@ def initialize_model_parallel(
|
|||||||
|
|
||||||
global _MOE_EP
|
global _MOE_EP
|
||||||
assert _MOE_EP is None, "expert model parallel group is already initialized"
|
assert _MOE_EP is None, "expert model parallel group is already initialized"
|
||||||
|
|
||||||
if moe_ep_size == tensor_model_parallel_size:
|
if moe_ep_size == tensor_model_parallel_size:
|
||||||
_MOE_EP = _TP
|
_MOE_EP = _TP
|
||||||
else:
|
else:
|
||||||
@@ -1539,7 +1552,6 @@ def initialize_model_parallel(
|
|||||||
|
|
||||||
global _MOE_TP
|
global _MOE_TP
|
||||||
assert _MOE_TP is None, "expert model parallel group is already initialized"
|
assert _MOE_TP is None, "expert model parallel group is already initialized"
|
||||||
|
|
||||||
if moe_tp_size == tensor_model_parallel_size:
|
if moe_tp_size == tensor_model_parallel_size:
|
||||||
_MOE_TP = _TP
|
_MOE_TP = _TP
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -43,11 +43,16 @@ _is_cpu = is_cpu()
|
|||||||
_is_xpu = is_xpu()
|
_is_xpu = is_xpu()
|
||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
if _is_flashinfer_available:
|
# if _is_flashinfer_available:
|
||||||
from flashinfer.norm import fused_add_rmsnorm
|
# from flashinfer.norm import fused_add_rmsnorm
|
||||||
else:
|
# else:
|
||||||
from sgl_kernel import fused_add_rmsnorm
|
from sgl_kernel import (
|
||||||
from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
|
fused_add_rmsnorm,
|
||||||
|
gemma_fused_add_rmsnorm,
|
||||||
|
gemma_rmsnorm,
|
||||||
|
rmsnorm,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
from aiter import rmsnorm2d_fwd as rms_norm
|
from aiter import rmsnorm2d_fwd as rms_norm
|
||||||
|
|||||||
@@ -17,12 +17,18 @@ from __future__ import annotations
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
|
||||||
|
from sglang.srt.model_executor.compilation.piecewise_context_manager import (
|
||||||
|
get_forward_context,
|
||||||
|
)
|
||||||
|
from sglang.srt.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
|
||||||
class AttentionType(Enum):
|
class AttentionType(Enum):
|
||||||
"""
|
"""
|
||||||
@@ -105,12 +111,58 @@ class RadixAttention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
k = k.view(-1, self.tp_k_head_num, self.v_head_dim)
|
k = k.view(-1, self.tp_k_head_num, self.v_head_dim)
|
||||||
|
|
||||||
return forward_batch.attn_backend.forward(
|
if forward_batch.forward_mode.is_extend() and get_forward_context() is not None:
|
||||||
q,
|
output = torch.zeros_like(q)
|
||||||
k,
|
torch.ops.sglang.unified_attention_with_output(
|
||||||
v,
|
q, k, v, output, save_kv_cache, self.layer_id
|
||||||
self,
|
)
|
||||||
forward_batch,
|
return output
|
||||||
save_kv_cache,
|
else:
|
||||||
**kwargs,
|
return forward_batch.attn_backend.forward(
|
||||||
)
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
self,
|
||||||
|
forward_batch,
|
||||||
|
save_kv_cache,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def unified_attention_with_output(
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
save_kv_cache: bool,
|
||||||
|
layer_id: int,
|
||||||
|
) -> None:
|
||||||
|
context = get_forward_context()
|
||||||
|
forward_batch = context.forward_batch
|
||||||
|
attention_layers = context.attention_layers
|
||||||
|
attention_layer = attention_layers[layer_id]
|
||||||
|
ret = forward_batch.attn_backend.forward(
|
||||||
|
query, key, value, attention_layer, forward_batch, save_kv_cache
|
||||||
|
)
|
||||||
|
assert output.shape == ret.shape
|
||||||
|
output.copy_(ret)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def unified_attention_with_output_fake(
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
save_kv_cache: bool,
|
||||||
|
layer_id: int,
|
||||||
|
) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="unified_attention_with_output",
|
||||||
|
op_func=unified_attention_with_output,
|
||||||
|
mutates_args=["output"],
|
||||||
|
fake_impl=unified_attention_with_output_fake,
|
||||||
|
)
|
||||||
|
|||||||
435
python/sglang/srt/model_executor/compilation/backend.py
Normal file
435
python/sglang/srt/model_executor/compilation/backend.py
Normal file
@@ -0,0 +1,435 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/backend.py
|
||||||
|
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import dataclasses
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pprint
|
||||||
|
import time
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.fx as fx
|
||||||
|
from torch._dispatch.python import enable_python_dispatcher
|
||||||
|
|
||||||
|
from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig
|
||||||
|
from sglang.srt.model_executor.compilation.compilation_counter import (
|
||||||
|
compilation_counter,
|
||||||
|
)
|
||||||
|
from sglang.srt.model_executor.compilation.compiler_interface import InductorAdaptor
|
||||||
|
from sglang.srt.model_executor.compilation.cuda_piecewise_backend import (
|
||||||
|
CUDAPiecewiseBackend,
|
||||||
|
)
|
||||||
|
from sglang.srt.model_executor.compilation.pass_manager import PostGradPassManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def make_compiler():
|
||||||
|
return InductorAdaptor()
|
||||||
|
|
||||||
|
|
||||||
|
class CompilerManager:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
self.cache = dict()
|
||||||
|
self.is_cache_updated = False
|
||||||
|
self.compiler = make_compiler()
|
||||||
|
|
||||||
|
def compute_hash(self):
|
||||||
|
return self.compiler.compute_hash()
|
||||||
|
|
||||||
|
def initialize_cache(
|
||||||
|
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||||
|
):
|
||||||
|
self.disable_cache = disable_cache
|
||||||
|
self.cache_dir = cache_dir
|
||||||
|
self.cache_file_path = os.path.join(cache_dir, "sglang_compile_cache.py")
|
||||||
|
|
||||||
|
if not disable_cache and os.path.exists(self.cache_file_path):
|
||||||
|
with open(self.cache_file_path) as f:
|
||||||
|
self.cache = ast.literal_eval(f.read())
|
||||||
|
|
||||||
|
self.compiler.initialize_cache(
|
||||||
|
cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_to_file(self):
|
||||||
|
if self.disable_cache or not self.is_cache_updated:
|
||||||
|
return
|
||||||
|
printer = pprint.PrettyPrinter(indent=4)
|
||||||
|
data = printer.pformat(self.cache)
|
||||||
|
with open(self.cache_file_path, "w") as f:
|
||||||
|
f.write(data)
|
||||||
|
|
||||||
|
def load(
|
||||||
|
self,
|
||||||
|
graph: fx.GraphModule,
|
||||||
|
example_inputs: list[Any],
|
||||||
|
graph_index: int,
|
||||||
|
runtime_shape: Optional[int] = None,
|
||||||
|
) -> Optional[Callable]:
|
||||||
|
handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
|
||||||
|
compiled_graph = self.compiler.load(
|
||||||
|
handle, graph, example_inputs, graph_index, runtime_shape
|
||||||
|
)
|
||||||
|
if runtime_shape is None:
|
||||||
|
logger.debug(
|
||||||
|
"Directly load the %s-th graph for dynamic shape from %s via "
|
||||||
|
"handle %s",
|
||||||
|
graph_index,
|
||||||
|
self.compiler.name,
|
||||||
|
handle,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"Directly load the %s-th graph for shape %s from %s via " "handle %s",
|
||||||
|
graph_index,
|
||||||
|
str(runtime_shape),
|
||||||
|
self.compiler.name,
|
||||||
|
handle,
|
||||||
|
)
|
||||||
|
return compiled_graph
|
||||||
|
|
||||||
|
def compile(
|
||||||
|
self,
|
||||||
|
graph: fx.GraphModule,
|
||||||
|
example_inputs,
|
||||||
|
inductor_config: dict[str, Any],
|
||||||
|
graph_index: int = 0,
|
||||||
|
num_graphs: int = 1,
|
||||||
|
runtime_shape: Optional[int] = None,
|
||||||
|
) -> Any:
|
||||||
|
if graph_index == 0:
|
||||||
|
# before compiling the first graph, record the start time
|
||||||
|
global compilation_start_time
|
||||||
|
compilation_start_time = time.time()
|
||||||
|
|
||||||
|
compilation_counter.num_backend_compilations += 1
|
||||||
|
|
||||||
|
compiled_graph = None
|
||||||
|
|
||||||
|
# TODO(Yuwei): support cache loading
|
||||||
|
|
||||||
|
# no compiler cached the graph, or the cache is disabled,
|
||||||
|
# we need to compile it
|
||||||
|
if isinstance(self.compiler, InductorAdaptor):
|
||||||
|
maybe_key = None
|
||||||
|
else:
|
||||||
|
maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
|
||||||
|
compiled_graph, handle = self.compiler.compile(
|
||||||
|
graph, example_inputs, inductor_config, runtime_shape, maybe_key
|
||||||
|
)
|
||||||
|
|
||||||
|
assert compiled_graph is not None, "Failed to compile the graph"
|
||||||
|
|
||||||
|
# store the artifact in the cache
|
||||||
|
if handle is not None:
|
||||||
|
self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
|
||||||
|
compilation_counter.num_cache_entries_updated += 1
|
||||||
|
self.is_cache_updated = True
|
||||||
|
if graph_index == 0:
|
||||||
|
# adds some info logging for the first graph
|
||||||
|
if runtime_shape is None:
|
||||||
|
logger.info("Cache the graph for dynamic shape for later use")
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Cache the graph of shape %s for later use", str(runtime_shape)
|
||||||
|
)
|
||||||
|
if runtime_shape is None:
|
||||||
|
logger.debug(
|
||||||
|
"Store the %s-th graph for dynamic shape from %s via " "handle %s",
|
||||||
|
graph_index,
|
||||||
|
self.compiler.name,
|
||||||
|
handle,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"Store the %s-th graph for shape %s from %s via handle %s",
|
||||||
|
graph_index,
|
||||||
|
str(runtime_shape),
|
||||||
|
self.compiler.name,
|
||||||
|
handle,
|
||||||
|
)
|
||||||
|
|
||||||
|
# after compiling the last graph, record the end time
|
||||||
|
if graph_index == num_graphs - 1:
|
||||||
|
now = time.time()
|
||||||
|
elapsed = now - compilation_start_time
|
||||||
|
if runtime_shape is None:
|
||||||
|
logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Compiling a graph for shape %s takes %.2f s",
|
||||||
|
runtime_shape,
|
||||||
|
elapsed,
|
||||||
|
)
|
||||||
|
|
||||||
|
return compiled_graph
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class SplitItem:
|
||||||
|
submod_name: str
|
||||||
|
graph_id: int
|
||||||
|
is_splitting_graph: bool
|
||||||
|
graph: fx.GraphModule
|
||||||
|
|
||||||
|
|
||||||
|
def split_graph(
|
||||||
|
graph: fx.GraphModule, ops: list[str]
|
||||||
|
) -> tuple[fx.GraphModule, list[SplitItem]]:
|
||||||
|
# split graph by ops
|
||||||
|
subgraph_id = 0
|
||||||
|
node_to_subgraph_id = {}
|
||||||
|
split_op_graphs = []
|
||||||
|
for node in graph.graph.nodes:
|
||||||
|
if node.op in ("output", "placeholder"):
|
||||||
|
continue
|
||||||
|
if node.op == "call_function" and str(node.target) in ops:
|
||||||
|
subgraph_id += 1
|
||||||
|
node_to_subgraph_id[node] = subgraph_id
|
||||||
|
split_op_graphs.append(subgraph_id)
|
||||||
|
subgraph_id += 1
|
||||||
|
else:
|
||||||
|
node_to_subgraph_id[node] = subgraph_id
|
||||||
|
|
||||||
|
# `keep_original_order` is important!
|
||||||
|
# otherwise pytorch might reorder the nodes and
|
||||||
|
# the semantics of the graph will change when we
|
||||||
|
# have mutations in the graph
|
||||||
|
split_gm = torch.fx.passes.split_module.split_module(
|
||||||
|
graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
|
||||||
|
names = [name for (name, module) in split_gm.named_modules()]
|
||||||
|
|
||||||
|
for name in names:
|
||||||
|
if "." in name or name == "":
|
||||||
|
# recursive child module or the root module
|
||||||
|
continue
|
||||||
|
|
||||||
|
module = getattr(split_gm, name)
|
||||||
|
|
||||||
|
graph_id = int(name.replace("submod_", ""))
|
||||||
|
outputs.append(SplitItem(name, graph_id, (graph_id in split_op_graphs), module))
|
||||||
|
|
||||||
|
# sort by intetger graph_id, rather than string name
|
||||||
|
outputs.sort(key=lambda x: x.graph_id)
|
||||||
|
|
||||||
|
return split_gm, outputs
|
||||||
|
|
||||||
|
|
||||||
|
# we share the global graph pool among all the backends
|
||||||
|
global_graph_pool = None
|
||||||
|
|
||||||
|
compilation_start_time = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
module: torch.fx.GraphModule,
|
||||||
|
compile_submod_names: list[str],
|
||||||
|
inductor_config: dict[str, Any],
|
||||||
|
graph_pool,
|
||||||
|
compile_config: CompilationConfig,
|
||||||
|
sglang_backend: "SGLangBackend",
|
||||||
|
):
|
||||||
|
super().__init__(module)
|
||||||
|
from torch._guards import detect_fake_mode
|
||||||
|
|
||||||
|
self.fake_mode = detect_fake_mode()
|
||||||
|
self.compile_submod_names = compile_submod_names
|
||||||
|
self.graph_pool = graph_pool
|
||||||
|
self.sglang_backend = sglang_backend
|
||||||
|
# When True, it annoyingly dumps the torch.fx.Graph on errors.
|
||||||
|
self.extra_traceback = False
|
||||||
|
self.inductor_config = inductor_config
|
||||||
|
self.compile_config = compile_config
|
||||||
|
|
||||||
|
def run(self, *args):
|
||||||
|
fake_args = [
|
||||||
|
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
|
||||||
|
for t in args
|
||||||
|
]
|
||||||
|
with self.fake_mode, enable_python_dispatcher():
|
||||||
|
return super().run(*fake_args)
|
||||||
|
|
||||||
|
def call_module(
|
||||||
|
self,
|
||||||
|
target: torch.fx.node.Target,
|
||||||
|
args: tuple[torch.fx.node.Argument, ...],
|
||||||
|
kwargs: dict[str, Any],
|
||||||
|
) -> Any:
|
||||||
|
assert isinstance(target, str)
|
||||||
|
output = super().call_module(target, args, kwargs)
|
||||||
|
|
||||||
|
if target in self.compile_submod_names:
|
||||||
|
index = self.compile_submod_names.index(target)
|
||||||
|
submod = self.fetch_attr(target)
|
||||||
|
sym_shape_indices = [
|
||||||
|
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
|
||||||
|
]
|
||||||
|
global compilation_start_time
|
||||||
|
compiled_graph_for_dynamic_shape = (
|
||||||
|
self.sglang_backend.compiler_manager.compile(
|
||||||
|
submod,
|
||||||
|
args,
|
||||||
|
self.inductor_config,
|
||||||
|
graph_index=index,
|
||||||
|
num_graphs=len(self.compile_submod_names),
|
||||||
|
runtime_shape=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.module.__dict__[target] = CUDAPiecewiseBackend(
|
||||||
|
submod,
|
||||||
|
self.compile_config,
|
||||||
|
self.inductor_config,
|
||||||
|
self.graph_pool,
|
||||||
|
index,
|
||||||
|
len(self.compile_submod_names),
|
||||||
|
sym_shape_indices,
|
||||||
|
compiled_graph_for_dynamic_shape,
|
||||||
|
self.sglang_backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
compilation_counter.num_piecewise_capturable_graphs_seen += 1
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
model_tag: str = "backbone"
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def set_model_tag(tag: str):
|
||||||
|
"""Context manager to set the model tag."""
|
||||||
|
global model_tag
|
||||||
|
assert (
|
||||||
|
tag != model_tag
|
||||||
|
), f"Model tag {tag} is the same as the current tag {model_tag}."
|
||||||
|
old_tag = model_tag
|
||||||
|
model_tag = tag
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
model_tag = old_tag
|
||||||
|
|
||||||
|
|
||||||
|
class SGLangBackend:
|
||||||
|
|
||||||
|
graph_pool: Any
|
||||||
|
_called: bool = False
|
||||||
|
# the graph we compiled
|
||||||
|
graph: fx.GraphModule
|
||||||
|
# the stiching graph module for all the piecewise graphs
|
||||||
|
split_gm: fx.GraphModule
|
||||||
|
piecewise_graphs: list[SplitItem]
|
||||||
|
returned_callable: Callable
|
||||||
|
# Inductor passes to run on the graph pre-defunctionalization
|
||||||
|
post_grad_passes: Sequence[Callable]
|
||||||
|
sym_tensor_indices: list[int]
|
||||||
|
input_buffers: list[torch.Tensor]
|
||||||
|
compiler_manager: CompilerManager
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: CompilationConfig,
|
||||||
|
graph_pool: Any,
|
||||||
|
):
|
||||||
|
assert graph_pool is not None
|
||||||
|
self.graph_pool = graph_pool
|
||||||
|
|
||||||
|
self.post_grad_pass_manager = PostGradPassManager()
|
||||||
|
self.sym_tensor_indices = []
|
||||||
|
self.input_buffers = []
|
||||||
|
|
||||||
|
self.compiler_manager = CompilerManager()
|
||||||
|
self.inductor_config = {
|
||||||
|
"enable_auto_functionalized_v2": False,
|
||||||
|
}
|
||||||
|
self.compile_config = config
|
||||||
|
|
||||||
|
def configure_post_pass(self):
|
||||||
|
self.post_grad_pass_manager.configure()
|
||||||
|
self.inductor_config["post_grad_custom_post_pass"] = self.post_grad_pass_manager
|
||||||
|
|
||||||
|
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
|
||||||
|
base_cache_dir = os.path.expanduser(
|
||||||
|
os.getenv("SGLANG_CACHE_DIR", "~/.cache/sglang/")
|
||||||
|
)
|
||||||
|
|
||||||
|
cache_hash = self.compiler_manager.compute_hash()
|
||||||
|
cache_dir = os.path.join(
|
||||||
|
base_cache_dir,
|
||||||
|
"torch_compile_cache",
|
||||||
|
cache_hash,
|
||||||
|
)
|
||||||
|
|
||||||
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
rank = 0
|
||||||
|
dp_rank = 0
|
||||||
|
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", model_tag)
|
||||||
|
os.makedirs(local_cache_dir, exist_ok=True)
|
||||||
|
self.compiler_manager.initialize_cache(
|
||||||
|
local_cache_dir, disable_cache=False, prefix=""
|
||||||
|
)
|
||||||
|
compilation_counter.num_graphs_seen += 1
|
||||||
|
|
||||||
|
assert not self._called, "SGLangBackend can only be called once"
|
||||||
|
|
||||||
|
self.graph = graph
|
||||||
|
self.configure_post_pass()
|
||||||
|
|
||||||
|
self.split_gm, self.piecewise_graphs = split_graph(
|
||||||
|
graph, ["sglang.unified_attention_with_output"]
|
||||||
|
)
|
||||||
|
|
||||||
|
from torch._dynamo.utils import lazy_format_graph_code
|
||||||
|
|
||||||
|
# depyf will hook lazy_format_graph_code and dump the graph
|
||||||
|
# for debugging, no need to print the graph here
|
||||||
|
lazy_format_graph_code("before split", self.graph)
|
||||||
|
lazy_format_graph_code("after split", self.split_gm)
|
||||||
|
|
||||||
|
compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
|
||||||
|
|
||||||
|
submod_names_to_compile = [
|
||||||
|
item.submod_name
|
||||||
|
for item in self.piecewise_graphs
|
||||||
|
if not item.is_splitting_graph
|
||||||
|
]
|
||||||
|
|
||||||
|
PiecewiseCompileInterpreter(
|
||||||
|
self.split_gm,
|
||||||
|
submod_names_to_compile,
|
||||||
|
self.inductor_config,
|
||||||
|
self.graph_pool,
|
||||||
|
self.compile_config,
|
||||||
|
self,
|
||||||
|
).run(*example_inputs)
|
||||||
|
|
||||||
|
graph_path = os.path.join(local_cache_dir, "computation_graph.py")
|
||||||
|
if not os.path.exists(graph_path):
|
||||||
|
# code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
|
||||||
|
# use `print_readable` because it can include submodules
|
||||||
|
src = (
|
||||||
|
"from __future__ import annotations\nimport torch\n"
|
||||||
|
+ self.split_gm.print_readable(print_output=False)
|
||||||
|
)
|
||||||
|
src = src.replace("<lambda>", "GraphModule")
|
||||||
|
with open(graph_path, "w") as f:
|
||||||
|
f.write(src)
|
||||||
|
|
||||||
|
logger.debug("Computation graph saved to %s", graph_path)
|
||||||
|
|
||||||
|
self._called = True
|
||||||
|
return self.split_gm
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_config.py
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Yuwei): support better compile config support
|
||||||
|
class CompilationConfig:
|
||||||
|
def __init__(self, capture_sizes: List[int]):
|
||||||
|
self.traced_files = set()
|
||||||
|
self.capture_sizes = capture_sizes
|
||||||
|
|
||||||
|
def add_traced_file(self, file_path: str):
|
||||||
|
self.traced_files.add(file_path)
|
||||||
|
|
||||||
|
def get_traced_files(self):
|
||||||
|
return self.traced_files
|
||||||
|
|
||||||
|
def get_capture_sizes(self):
|
||||||
|
return self.capture_sizes
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_counter.py
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import dataclasses
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class CompilationCounter:
|
||||||
|
num_models_seen: int = 0
|
||||||
|
num_graphs_seen: int = 0
|
||||||
|
# including the splitting ops
|
||||||
|
num_piecewise_graphs_seen: int = 0
|
||||||
|
# not including the splitting ops
|
||||||
|
num_piecewise_capturable_graphs_seen: int = 0
|
||||||
|
num_backend_compilations: int = 0
|
||||||
|
# Number of gpu_model_runner attempts to trigger CUDAGraphs capture
|
||||||
|
num_gpu_runner_capture_triggers: int = 0
|
||||||
|
# Number of CUDAGraphs captured
|
||||||
|
num_cudagraph_captured: int = 0
|
||||||
|
# InductorAdapter.compile calls
|
||||||
|
num_inductor_compiles: int = 0
|
||||||
|
# EagerAdapter.compile calls
|
||||||
|
num_eager_compiles: int = 0
|
||||||
|
# The number of time vLLM's compiler cache entry was updated
|
||||||
|
num_cache_entries_updated: int = 0
|
||||||
|
# The number of standalone_compile compiled artifacts saved
|
||||||
|
num_compiled_artifacts_saved: int = 0
|
||||||
|
# Number of times a model was loaded with CompilationLevel.DYNAMO_AS_IS
|
||||||
|
dynamo_as_is_count: int = 0
|
||||||
|
|
||||||
|
def clone(self) -> "CompilationCounter":
|
||||||
|
return copy.deepcopy(self)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def expect(self, **kwargs):
|
||||||
|
old = self.clone()
|
||||||
|
yield
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
assert getattr(self, k) - getattr(old, k) == v, (
|
||||||
|
f"{k} not as expected, before it is {getattr(old, k)}"
|
||||||
|
f", after it is {getattr(self, k)}, "
|
||||||
|
f"expected diff is {v}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
compilation_counter = CompilationCounter()
|
||||||
210
python/sglang/srt/model_executor/compilation/compile.py
Normal file
210
python/sglang/srt/model_executor/compilation/compile.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
import contextvars
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_COMPILE_ENABLED = contextvars.ContextVar("_COMPILE_ENABLED", default=False)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def set_compiled(enabled: bool = True):
|
||||||
|
token = _COMPILE_ENABLED.set(enabled)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
_COMPILE_ENABLED.reset(token)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class IntermediateTensors:
|
||||||
|
"""For all pipeline stages except the last, we need to return the hidden
|
||||||
|
states and residuals to be sent to the next stage. This data structure
|
||||||
|
contains the hidden states and residuals for a request.
|
||||||
|
|
||||||
|
Each stage also needs to handle its own finished_sending and
|
||||||
|
finished_recving in case of kv transfer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tensors: dict[str, torch.Tensor]
|
||||||
|
# [req_ids]
|
||||||
|
finished_sending: Optional[set[str]] = None
|
||||||
|
finished_recving: Optional[set[str]] = None
|
||||||
|
|
||||||
|
def __init__(self, tensors):
|
||||||
|
# manually define this function, so that
|
||||||
|
# Dynamo knows `IntermediateTensors()` comes from this file.
|
||||||
|
# Otherwise, dataclass will generate this function by evaluating
|
||||||
|
# a string, and we will lose the information about the source file.
|
||||||
|
self.tensors = tensors
|
||||||
|
|
||||||
|
def __getitem__(self, key: Union[str, slice]):
|
||||||
|
if isinstance(key, str):
|
||||||
|
return self.tensors[key]
|
||||||
|
elif isinstance(key, slice):
|
||||||
|
return self.__class__({k: v[key] for k, v in self.tensors.items()})
|
||||||
|
|
||||||
|
def __setitem__(self, key: str, value: torch.Tensor):
|
||||||
|
self.tensors[key] = value
|
||||||
|
|
||||||
|
def items(self):
|
||||||
|
return self.tensors.items()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.tensors)
|
||||||
|
|
||||||
|
def __eq__(self, other: object):
|
||||||
|
return isinstance(other, self.__class__) and self
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"IntermediateTensors(tensors={self.tensors})"
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_dims(dims, ndim: int):
|
||||||
|
dims = [dims] if isinstance(dims, int) else list(dims)
|
||||||
|
return [d if d >= 0 else ndim + d for d in dims]
|
||||||
|
|
||||||
|
|
||||||
|
class _MaybeIntermediateTensors:
    """Duck-typed check to support your IntermediateTensors without importing."""

    def __init__(self, obj):
        # "Looks like IntermediateTensors" == carries a dict-valued
        # `tensors` attribute; anything else is treated as an opaque value.
        tensors = getattr(obj, "tensors", None)
        self.is_intermediate = isinstance(tensors, dict)
        self.obj = obj
|
||||||
|
|
||||||
|
|
||||||
|
def _mark_dynamic_on_value(val, dims):
    """Mark ``dims`` dynamic on a tensor, or on every tensor inside an
    IntermediateTensors-like container; silently ignore anything else."""
    if isinstance(val, torch.Tensor):
        torch._dynamo.mark_dynamic(val, _normalize_dims(dims, val.ndim))
        return
    maybe = _MaybeIntermediateTensors(val)
    if not maybe.is_intermediate:
        # None or non-tensor values carry no dynamic-shape information.
        return
    for tensor in maybe.obj.tensors.values():
        torch._dynamo.mark_dynamic(tensor, _normalize_dims(dims, tensor.ndim))
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_dynamic_arg_dims_from_annotations(forward_fn):
    """Derive ``{param_name: dynamic_dim}`` from a forward signature.

    Parameters annotated as ``torch.Tensor`` / ``Optional[torch.Tensor]``,
    or whose annotation names ``IntermediateTensors`` (directly or inside a
    Union/Optional), get dim 0 marked dynamic.

    Raises:
        ValueError: when no parameter qualifies.
    """
    inferred = {}
    for param_name, param in inspect.signature(forward_fn).parameters.items():
        annotation = param.annotation
        # Optional[torch.Tensor] etc. are matched via __args__[0].__name__.
        first_arg = getattr(annotation, "__args__", [None])[0]
        if annotation is torch.Tensor or getattr(first_arg, "__name__", "") == "Tensor":
            inferred[param_name] = 0
            continue
        union_names = [
            getattr(a, "__name__", "") for a in getattr(annotation, "__args__", [])
        ]
        if (
            getattr(annotation, "__name__", "") == "IntermediateTensors"
            or "IntermediateTensors" in union_names
        ):
            inferred[param_name] = 0
    if not inferred:
        raise ValueError("No dynamic dims inferred; pass dynamic_arg_dims explicitly.")
    return inferred
|
||||||
|
|
||||||
|
|
||||||
|
def install_torch_compiled(
    module: torch.nn.Module,
    *,
    dynamic_arg_dims: dict[str, Union[int, list[int]]] | None = None,
    backend_factory: Optional[Callable[[torch.fx.GraphModule, list], Callable]] = None,
    compile_config: CompilationConfig = None,
    fullgraph: bool = True,
    graph_pool: Any = None,
):
    """Replace ``module.forward`` with a trampoline that lazily compiles it.

    On the first call made while the ``_COMPILE_ENABLED`` context flag is on,
    the class-level ``forward`` is compiled via ``torch.compile`` and cached;
    later flag-on calls reuse the compiled callable. Flag-off calls run the
    original, uncompiled forward.

    Args:
        module: the module whose ``forward`` is wrapped in place.
        dynamic_arg_dims: map of parameter name -> dim(s) to mark dynamic;
            inferred from the forward's type annotations when omitted.
        backend_factory: builds the Dynamo backend from
            ``(graph_module, example_inputs)``; defaults to ``SGLangBackend``.
        compile_config: forwarded to the default ``SGLangBackend``.
        fullgraph: passed through to ``torch.compile``.
        graph_pool: forwarded to the default ``SGLangBackend``.

    Returns:
        The same ``module``, with ``forward`` replaced.

    Raises:
        TypeError: if ``module.__class__.forward`` is not callable.
    """
    unbound_fwd = module.__class__.forward
    if not callable(unbound_fwd):
        raise TypeError("module.__class__.forward must be callable")
    # The pristine code object identifies this forward inside the bytecode hook.
    original_code = unbound_fwd.__code__

    dyn_map = dynamic_arg_dims or _infer_dynamic_arg_dims_from_annotations(unbound_fwd)

    if backend_factory is None:
        from sglang.srt.model_executor.compilation.backend import SGLangBackend

        backend_factory = lambda gm, ex: SGLangBackend(compile_config, graph_pool)(
            gm, ex
        )

    # Bytecode objects Dynamo produced for *this* module's forward.
    compiled_codes: list[type(original_code)] = []
    # Mutable closure state shared by _ensure_compiled and trampoline.
    state = {"compiled": False, "compiled_callable": None}

    def bytecode_hook(old_code, new_code):
        # Dynamo calls this for every transformed code object; keep only the
        # transformation of our forward, triggered by our module instance.
        if old_code is not original_code:
            return
        # Walk up the stack to Dynamo's `_compile` frame (convert_frame.py)
        # to reach the frame actually being compiled.
        frame = sys._getframe()
        while frame and frame.f_back:
            frame = frame.f_back
            if (
                frame.f_code.co_name == "_compile"
                and os.path.basename(frame.f_code.co_filename) == "convert_frame.py"
            ):
                break
        try:
            dynamo_frame = frame.f_locals["frame"]
        except Exception:
            # Dynamo internals changed or frame not found; silently skip.
            return
        if dynamo_frame.f_code is not old_code:
            return
        # Only record compilations whose `self` is this exact module instance.
        if dynamo_frame.f_locals.get("self") is not module:
            return
        compiled_codes.append(new_code)

    torch._dynamo.convert_frame.register_bytecode_hook(bytecode_hook)

    def _ensure_compiled(self, *args, **kwargs):
        """Compile on first use (with flag ON)."""
        if state["compiled"]:
            return
        # Mark dynamic dims only when we are about to compile
        sig = inspect.signature(unbound_fwd)
        ba = sig.bind(self, *args, **kwargs)
        ba.apply_defaults()
        for name, dims in (dyn_map or {}).items():
            if name in ba.arguments:
                val = ba.arguments[name]
                if val is not None:
                    _mark_dynamic_on_value(val, dims)

        # Avoid cross-instance cache reuse
        torch._dynamo.eval_frame.remove_from_cache(unbound_fwd.__code__)

        bound = types.MethodType(unbound_fwd, self)
        compiled_callable = torch.compile(
            bound, fullgraph=fullgraph, backend=backend_factory
        )

        # Trigger Dynamo so bytecode hook can capture
        compiled_callable(*args, **kwargs)

        state["compiled"] = True
        state["compiled_callable"] = compiled_callable

    def trampoline(self, *args, **kwargs):
        # Dispatch between compiled and original forward based on the
        # _COMPILE_ENABLED context variable (defined elsewhere in this file).
        use_compiled = _COMPILE_ENABLED.get()
        if use_compiled:
            if not state["compiled"]:
                _ensure_compiled(self, *args, **kwargs)

            compiled_callable = state["compiled_callable"]
            return compiled_callable(*args, **kwargs)
        else:
            # Explicitly run the original uncompiled forward
            return unbound_fwd(self, *args, **kwargs)

    module.forward = types.MethodType(trampoline, module)
    return module
|
||||||
@@ -0,0 +1,479 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compiler_interface.py
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import copy
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
from contextlib import ExitStack
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch._inductor.compile_fx
|
||||||
|
import torch.fx as fx
|
||||||
|
|
||||||
|
from sglang.srt.model_executor.compilation.compilation_counter import (
|
||||||
|
compilation_counter,
|
||||||
|
)
|
||||||
|
from sglang.srt.model_executor.compilation.inductor_pass import pass_context
|
||||||
|
|
||||||
|
|
||||||
|
class CompilerInterface:
    """Abstract interface every backend compiler adaptor implements.

    Concrete subclasses wrap one compiler (e.g. Inductor) and expose a
    uniform compile / load / cache-initialization API to the backend.
    """

    # Human-readable compiler name (e.g. "inductor"); class-level attribute.
    name: str

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """Prepare the compiler's own caching under ``cache_dir``.

        ``prefix`` combined with ``cache_dir`` identifies the shared base
        cache directory when several parts of a model are compiled into
        per-part sub-directories, e.g.::

            cache_dir = "/path/to/dir/backbone",   prefix = "backbone"
            cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"

        Default implementation does nothing.
        """
        pass

    def compute_hash(self) -> str:
        """Return a hash of compiler-specific state for the model cache key.

        Only information specific to this compiler belongs here; the rest of
        the configuration is hashed elsewhere. Default: no extra state.
        """
        return ""

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        """Compile ``graph`` for ``example_inputs`` under ``compiler_config``.

        ``runtime_shape`` is None for a dynamic-shape compilation; otherwise
        it is the single variable dimension (batch size / number of tokens)
        all inputs share. Dynamo guarantees ``graph(*example_inputs)`` is
        valid.

        Returns ``(callable, handle)``. The handle is a plain Python object
        (ideally a string or file path) that ``load`` can use to reload the
        artifact; return None for the handle if caching is unsupported, and
        None for the callable if compilation failed. ``key`` tells adaptors
        that save standalone artifacts where to put them (``cache_dir/key``).

        Default implementation compiles nothing.
        """
        return None, None

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: Optional[int] = None,
    ) -> Callable:
        """Reload a compiled function from ``handle`` (second return value of
        ``compile``); raise if the handle is invalid.

        Default implementation: caching unsupported.
        """
        raise NotImplementedError("caching is not supported")
|
||||||
|
|
||||||
|
|
||||||
|
def get_inductor_factors() -> list[Any]:
|
||||||
|
factors: list[Any] = []
|
||||||
|
# summarize system state
|
||||||
|
from torch._inductor.codecache import CacheBase
|
||||||
|
|
||||||
|
system_factors = CacheBase.get_system()
|
||||||
|
factors.append(system_factors)
|
||||||
|
|
||||||
|
# summarize pytorch state
|
||||||
|
from torch._inductor.codecache import torch_key
|
||||||
|
|
||||||
|
torch_factors = torch_key()
|
||||||
|
factors.append(torch_factors)
|
||||||
|
return factors
|
||||||
|
|
||||||
|
|
||||||
|
class AlwaysHitShapeEnv:
    """A stand-in shape environment whose guard queries always succeed.

    A normal ``torch.compile`` run pairs each Dynamo bytecode compilation
    with one Inductor compilation, and Inductor consults the Dynamo
    context's shape environment for dynamic-shape/guard information.

    Here, Dynamo runs once while Inductor is invoked repeatedly for a
    general shape plus several specific shapes, *outside* any Dynamo
    context -- so no real shape environment exists and the Inductor code
    cache lookup would fail. Patching this always-true implementation in
    makes the lookup hit, letting us compile the graph per shape as needed.

    The method set below was arrived at empirically (trial and error).
    """

    def __init__(self) -> None:
        # The cache iterates over guards; empty means "no guards to check".
        self.guards: list[Any] = []

    def evaluate_guards_expression(self, *args, **kwargs):
        # Every guard expression is reported as holding.
        return True

    def get_pruned_guards(self, *args, **kwargs):
        return []

    def produce_guards_expression(self, *args, **kwargs):
        return ""
|
||||||
|
|
||||||
|
|
||||||
|
class InductorAdaptor(CompilerInterface):
    """
    The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.

    Compiles FX graphs with ``torch._inductor.compile_fx`` while hijacking
    Inductor internals to (a) learn the cache key / artifact path of each
    compiled graph and (b) make cache lookups work outside a Dynamo
    tracing context.
    """

    name = "inductor"

    def compute_hash(self) -> str:
        """Hash system + torch state into a short cache-key component.

        10 hex chars keeps cache directory names readable while still
        disambiguating environments.
        """
        factors = get_inductor_factors()
        hash_str = hashlib.md5(
            str(factors).encode(), usedforsecurity=False
        ).hexdigest()[:10]
        return hash_str

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """Redirect Inductor's and Triton's on-disk caches into ``cache_dir``.

        ``cache_dir`` may end with ``prefix`` (one sub-directory per compiled
        part of the model); the caches are redirected to the shared base
        directory so every part reuses the same cache, and copying that one
        directory to another machine carries the whole cache along.
        """
        self.cache_dir = cache_dir
        self.prefix = prefix
        # Strip the per-part suffix to recover the shared base directory.
        self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
        if disable_cache:
            return
        # Redirect both caches via the environment variables Inductor and
        # Triton honor, pointing at sub-directories of the base cache dir.
        inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
        os.makedirs(inductor_cache, exist_ok=True)
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
        triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
        os.makedirs(triton_cache, exist_ok=True)
        os.environ["TRITON_CACHE_DIR"] = triton_cache

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        """Compile ``graph`` with Inductor.

        Returns ``(compiled_callable, (hash_str, file_path))`` where the
        handle components identify the cached artifact for ``load``.
        """
        compilation_counter.num_inductor_compiles += 1
        from torch._inductor.compile_fx import compile_fx

        current_config = {}
        if compiler_config is not None:
            current_config.update(compiler_config)

        # Enable the local FX graph cache but disable the remote one: the
        # hijacked functions below are only invoked on the local-cache path.
        current_config["fx_graph_cache"] = True
        current_config["fx_graph_remote_cache"] = False

        set_inductor_config(current_config, runtime_shape)

        # inductor can inplace modify the graph, so we need to copy it
        # see https://github.com/pytorch/pytorch/issues/138980
        graph = copy.deepcopy(graph)

        # It's the first time we compile this graph.
        # The assumption is that we don't have nested Inductor compilation:
        # compiled_fx_graph_hash will only be called once, and we can hook
        # it to get the hash of the compiled graph directly.
        hash_str, file_path = None, None
        from torch._inductor.codecache import FxGraphCache, compiled_fx_graph_hash

        if torch.__version__.startswith("2.5"):
            # torch 2.5: intercept FxGraphCache.load to discover the compiled
            # artifact and its on-disk path.
            original_load = FxGraphCache.load
            original_load_name = "torch._inductor.codecache.FxGraphCache.load"

            def hijack_load(*args, **kwargs):
                inductor_compiled_graph = original_load(*args, **kwargs)
                nonlocal file_path
                compiled_fn = inductor_compiled_graph.current_callable
                file_path = compiled_fn.__code__.co_filename  # noqa
                if not file_path.startswith(self.base_cache_dir):
                    # hooked in the align_inputs_from_check_idxs function
                    # in torch/_inductor/utils.py: the real compiled function
                    # hides inside the wrapper's closure.
                    for cell in compiled_fn.__closure__:
                        if not callable(cell.cell_contents):
                            continue
                        if cell.cell_contents.__code__.co_filename.startswith(
                            self.base_cache_dir
                        ):
                            # this is the real file path compiled from Inductor
                            file_path = cell.cell_contents.__code__.co_filename
                            break
                return inductor_compiled_graph

            hijacked_compile_fx_inner = (
                torch._inductor.compile_fx.compile_fx_inner
            )  # noqa
        elif torch.__version__ >= "2.6":
            # function renamed in 2.6: wrap compile_fx_inner instead of
            # patching FxGraphCache.load.
            original_load_name = None

            def hijacked_compile_fx_inner(*args, **kwargs):
                output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
                nonlocal hash_str
                inductor_compiled_graph = output
                if inductor_compiled_graph is not None:
                    nonlocal file_path
                    compiled_fn = inductor_compiled_graph.current_callable
                    file_path = compiled_fn.__code__.co_filename  # noqa
                    if not file_path.startswith(self.base_cache_dir):
                        # hooked in the align_inputs_from_check_idxs function
                        # in torch/_inductor/utils.py
                        for cell in compiled_fn.__closure__:
                            if not callable(cell.cell_contents):
                                continue
                            code = cell.cell_contents.__code__
                            if code.co_filename.startswith(self.base_cache_dir):
                                # this is the real file path
                                # compiled from Inductor
                                file_path = code.co_filename
                                break
                    hash_str = inductor_compiled_graph._fx_graph_cache_key
                return output

        def hijack_compiled_fx_graph_hash(*args, **kwargs):
            # Capture the FX graph cache key as it is computed.
            out = compiled_fx_graph_hash(*args, **kwargs)
            nonlocal hash_str
            hash_str = out[0]
            return out

        def _check_can_cache(*args, **kwargs):
            # no error means it can be cached.
            # Inductor refuses to cache the graph outside of Dynamo
            # tracing context, and also disables caching for graphs
            # with high-order ops.
            # For our use case, in either case, we want to cache the graph.
            # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
            return

        def _get_shape_env() -> AlwaysHitShapeEnv:
            return AlwaysHitShapeEnv()

        with ExitStack() as stack:
            # hijack to get the compiled graph itself (torch 2.5 only)
            if original_load_name is not None:
                stack.enter_context(patch(original_load_name, hijack_load))

            # for hijacking the hash of the compiled graph
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.compiled_fx_graph_hash",
                    hijack_compiled_fx_graph_hash,
                )
            )

            # for providing a dummy shape environment
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._get_shape_env",
                    _get_shape_env,
                )
            )

            from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache

            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
            if hasattr(AOTAutogradCache, "_get_shape_env"):
                stack.enter_context(
                    patch(
                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
                        _get_shape_env,
                    )
                )

            # for forcing the graph to be cached
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._check_can_cache",
                    _check_can_cache,
                )
            )

            # Dynamo metrics context, see method for more details.
            stack.enter_context(self.metrics_context())

            # Disable remote caching. When these are on, on remote cache-hit,
            # the monkey-patched functions never actually get called.
            # We assume and require the monkey-patched functions to get hit.
            # TODO(zou3519): we're going to replace this all with
            # standalone_compile sometime.
            stack.enter_context(
                torch._inductor.config.patch(fx_graph_remote_cache=False)
            )
            # InductorAdaptor (unfortunately) requires AOTAutogradCache
            # to be turned off to run. It will fail to acquire the hash_str
            # and error if not.
            # StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
            stack.enter_context(
                torch._functorch.config.patch(enable_autograd_cache=False)
            )
            stack.enter_context(
                torch._functorch.config.patch(enable_remote_autograd_cache=False)
            )

            with pass_context(runtime_shape):
                compiled_graph = compile_fx(
                    graph,
                    example_inputs,
                    inner_compile=hijacked_compile_fx_inner,
                    config_patches=current_config,
                )
            return compiled_graph, (hash_str, file_path)

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: Optional[int] = None,
    ) -> Callable:
        """Reload a compiled graph from Inductor's on-disk cache.

        ``handle`` is the ``(hash_str, file_path)`` tuple returned by
        ``compile``. Fails with an assertion if the cache lookup misses.
        """
        assert isinstance(handle, tuple)
        assert isinstance(handle[0], str)
        assert isinstance(handle[1], str)
        hash_str = handle[0]

        from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
        from torch._inductor.codecache import FxGraphCache

        with ExitStack() as exit_stack:
            exit_stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._get_shape_env",
                    lambda *args, **kwargs: AlwaysHitShapeEnv(),
                )
            )
            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
            if hasattr(AOTAutogradCache, "_get_shape_env"):
                exit_stack.enter_context(
                    patch(
                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
                        lambda *args, **kwargs: AlwaysHitShapeEnv(),
                    )
                )

            # Dynamo metrics context, see method for more details.
            exit_stack.enter_context(self.metrics_context())

            if torch.__version__.startswith("2.5"):
                inductor_compiled_graph = FxGraphCache._lookup_graph(
                    hash_str, example_inputs, True, False
                )
                # Fixed message: the two fragments previously concatenated
                # into "removethe" (missing space), and carried a pointless
                # f-prefix with no placeholders.
                assert inductor_compiled_graph is not None, (
                    "Inductor cache lookup failed. Please remove "
                    "the cache directory and try again."
                )
            elif torch.__version__ >= "2.6":
                from torch._inductor.output_code import CompiledFxGraphConstantsWithGm

                constants = CompiledFxGraphConstantsWithGm(graph)
                inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
                    hash_str, example_inputs, True, None, constants
                )
                assert inductor_compiled_graph is not None, (
                    "Inductor cache lookup failed. Please remove "
                    "the cache directory and try again."
                )

            # Inductor calling convention (function signature):
            # f(list) -> tuple
            # Dynamo calling convention (function signature):
            # f(*args) -> Any

            # need to know if the graph returns a tuple
            from torch._inductor.compile_fx import graph_returns_tuple

            returns_tuple = graph_returns_tuple(graph)

            # this is the callable we return to Dynamo to run
            def compiled_graph(*args):
                # convert args to list
                list_args = list(args)
                graph_output = inductor_compiled_graph(list_args)
                # unpack the tuple if needed
                if returns_tuple:
                    return graph_output
                else:
                    return graph_output[0]

            return compiled_graph

    def metrics_context(self) -> contextlib.AbstractContextManager:
        """
        This method returns the Dynamo metrics context (if it exists,
        otherwise a null context). It is used by various compile components.
        Present in torch>=2.6, it's used inside FxGraphCache in
        torch==2.6 (but not after). It might also be used in various other
        torch.compile internal functions.

        Because it is re-entrant, we always set it (even if entering via Dynamo
        and the context was already entered). We might want to revisit if it
        should be set at a different level of compilation.

        This is likely a bug in PyTorch: public APIs should not rely on
        manually setting up internal contexts. But we also rely on non-public
        APIs which might not provide these guarantees.
        """
        import torch._dynamo.utils

        return torch._dynamo.utils.get_metrics_context()
|
||||||
|
|
||||||
|
|
||||||
|
def set_inductor_config(config, runtime_shape):
    """Mutate ``config`` in place with shape-dependent Inductor options."""
    # Kernel autotuning only pays off when compiling for one concrete
    # batch size; leave dynamic-shape compilations untouched.
    if isinstance(runtime_shape, int):
        config.update(
            max_autotune=True,
            coordinate_descent_tuning=True,
        )
|
||||||
@@ -0,0 +1,230 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/cuda_piecewise_backend.py
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import logging
|
||||||
|
from contextlib import ExitStack
|
||||||
|
from typing import Any, Callable, Optional, Union
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.fx as fx
|
||||||
|
|
||||||
|
import sglang.srt.model_executor.compilation.weak_ref_tensor_jit
|
||||||
|
from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig
|
||||||
|
from sglang.srt.model_executor.compilation.compilation_counter import (
|
||||||
|
compilation_counter,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def weak_ref_tensor(tensor: Any) -> Any:
    """
    Create a weak reference to a tensor.
    The new tensor will share the same data as the original tensor,
    but will not keep the original tensor alive.
    Non-tensor values pass through unchanged.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    # TODO(yuwei): introduce weak_ref_tensor from sgl_kernel
    return torch.ops.jit_weak_ref_tensor.weak_ref_tensor(tensor)


def weak_ref_tensors(
    tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
    """
    Convenience function to create weak references to tensors,
    for single tensor, list of tensors or tuple of tensors.
    """
    if isinstance(tensors, torch.Tensor):
        return weak_ref_tensor(tensors)
    if isinstance(tensors, list):
        return [weak_ref_tensor(item) for item in tensors]
    if isinstance(tensors, tuple):
        return tuple(weak_ref_tensor(item) for item in tensors)
    raise ValueError("Invalid type for tensors")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
class ConcreteSizeEntry:
    """Per-shape bookkeeping for piecewise compilation and cudagraph capture."""

    runtime_shape: int  # the concrete batch size this entry tracks
    need_to_compile: bool  # the size is in compile_sizes
    use_cudagraph: bool  # the size is in cudagraph_capture_sizes

    compiled: bool = False  # set once a shape-specialized compile happened
    runnable: Callable = None  # type: ignore  # callable used to run this shape
    num_finished_warmup: int = 0  # warmup runs completed before capture
    cudagraph: Optional[torch.cuda.CUDAGraph] = None  # captured graph, once available
    output: Optional[Any] = None  # output recorded at capture time -- set by the backend

    # for cudagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[list[int]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CUDAPiecewiseBackend:
|
||||||
|
|
||||||
|
def __init__(
    self,
    graph: fx.GraphModule,
    compile_config: CompilationConfig,
    inductor_config: dict[str, Any],
    graph_pool: Any,
    piecewise_compile_index: int,
    total_piecewise_compiles: int,
    sym_shape_indices: list[int],
    compiled_graph_for_general_shape: Callable,
    sglang_backend,
):
    """
    The backend for piecewise compilation.
    It mainly handles the compilation and cudagraph capturing.

    We will compile `self.graph` once for the general shape,
    and then compile for different shapes specified in
    `compilation_config.compile_sizes`.

    Independently, we will capture cudagraph for different shapes.

    If a shape needs both compilation and cudagraph, we will
    compile it first, and then capture cudagraph.
    """
    self.graph = graph
    self.inductor_config = inductor_config
    self.graph_pool = graph_pool
    self.piecewise_compile_index = piecewise_compile_index
    self.total_piecewise_compiles = total_piecewise_compiles
    self.sglang_backend = sglang_backend

    # First/last pieces get special treatment (e.g. the last piece saves
    # the compilation cache once all sizes are compiled).
    self.is_first_graph = piecewise_compile_index == 0
    self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1

    # Shape-specialized compilation currently starts with no sizes.
    # (Idiom fix: `set()` instead of `set([])`.)
    self.compile_sizes: set[int] = set()
    self.compile_config = compile_config
    self.cudagraph_capture_sizes: set[int] = set(compile_config.get_capture_sizes())

    self.first_run_finished = False

    self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

    # Positions of the symbolic-shape (batch size) scalars in *args; used
    # at call time to read the runtime shape.
    self.sym_shape_indices = sym_shape_indices

    self.is_debugging_mode = True

    # the entries for different shapes that we need to either
    # compile or capture cudagraph
    self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}

    # to_be_compiled_sizes tracks the remaining sizes to compile,
    # and updates during the compilation process, so we need to copy it
    self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
    for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
        self.concrete_size_entries[shape] = ConcreteSizeEntry(
            runtime_shape=shape,
            need_to_compile=shape in self.compile_sizes,
            use_cudagraph=shape in self.cudagraph_capture_sizes,
        )
|
||||||
|
|
||||||
|
def check_for_ending_compilation(self):
    """On the final piecewise graph, once no shape-specific compilations
    remain, persist the inductor-graph hashes for the next run."""
    if not self.is_last_graph or self.to_be_compiled_sizes:
        return
    # Everything is compiled; save the cache index to disk.
    self.sglang_backend.compiler_manager.save_to_file()
|
||||||
|
|
||||||
|
def __call__(self, *args) -> Any:
|
||||||
|
if not self.first_run_finished:
|
||||||
|
self.first_run_finished = True
|
||||||
|
self.check_for_ending_compilation()
|
||||||
|
return self.compiled_graph_for_general_shape(*args)
|
||||||
|
runtime_shape = args[self.sym_shape_indices[0]]
|
||||||
|
if runtime_shape not in self.concrete_size_entries:
|
||||||
|
# we don't need to do anything for this shape
|
||||||
|
return self.compiled_graph_for_general_shape(*args)
|
||||||
|
|
||||||
|
entry = self.concrete_size_entries[runtime_shape]
|
||||||
|
|
||||||
|
if entry.runnable is None:
|
||||||
|
entry.runnable = self.compiled_graph_for_general_shape
|
||||||
|
|
||||||
|
if entry.need_to_compile and not entry.compiled:
|
||||||
|
entry.compiled = True
|
||||||
|
self.to_be_compiled_sizes.remove(runtime_shape)
|
||||||
|
# args are real arguments
|
||||||
|
entry.runnable = self.sglang_backend.compiler_manager.compile(
|
||||||
|
self.graph,
|
||||||
|
args,
|
||||||
|
self.inductor_config,
|
||||||
|
graph_index=self.piecewise_compile_index,
|
||||||
|
num_graphs=self.total_piecewise_compiles,
|
||||||
|
runtime_shape=runtime_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
# finished compilations for all required shapes
|
||||||
|
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||||
|
self.check_for_ending_compilation()
|
||||||
|
|
||||||
|
# Skip CUDA graphs if this entry doesn't use them OR
|
||||||
|
# if we're supposed to skip them globally
|
||||||
|
# skip_cuda_graphs = get_forward_context().skip_cuda_graphs
|
||||||
|
# if not entry.use_cudagraph or skip_cuda_graphs:
|
||||||
|
# return entry.runnable(*args)
|
||||||
|
|
||||||
|
if entry.cudagraph is None:
|
||||||
|
if entry.num_finished_warmup < 1: # noqa
|
||||||
|
entry.num_finished_warmup += 1
|
||||||
|
return entry.runnable(*args)
|
||||||
|
|
||||||
|
input_addresses = [
|
||||||
|
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||||
|
]
|
||||||
|
entry.input_addresses = input_addresses
|
||||||
|
cudagraph = torch.cuda.CUDAGraph()
|
||||||
|
|
||||||
|
with ExitStack() as stack:
|
||||||
|
if not self.is_first_graph:
|
||||||
|
# during every model forward, we will capture
|
||||||
|
# many pieces of cudagraphs (roughly one per layer).
|
||||||
|
# running gc again and again across layers will
|
||||||
|
# make the cudagraph capture very slow.
|
||||||
|
# therefore, we only run gc for the first graph,
|
||||||
|
# and disable gc for the rest of the graphs.
|
||||||
|
stack.enter_context(patch("gc.collect", lambda: None))
|
||||||
|
stack.enter_context(patch("torch.cuda.empty_cache", lambda: None))
|
||||||
|
|
||||||
|
# mind-exploding: carefully manage the reference and memory.
|
||||||
|
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
|
||||||
|
# `output` is managed by pytorch's cudagraph pool
|
||||||
|
output = entry.runnable(*args)
|
||||||
|
if self.is_last_graph:
|
||||||
|
# by converting it to weak ref,
|
||||||
|
# the original `output` will immediately be released
|
||||||
|
# to save memory. It is only safe to do this for
|
||||||
|
# the last graph, because the output of the last graph
|
||||||
|
# will not be used by any other cuda graph.
|
||||||
|
output = weak_ref_tensors(output)
|
||||||
|
|
||||||
|
# here we always use weak ref for the output
|
||||||
|
# to save memory
|
||||||
|
entry.output = weak_ref_tensors(output)
|
||||||
|
entry.cudagraph = cudagraph
|
||||||
|
|
||||||
|
compilation_counter.num_cudagraph_captured += 1
|
||||||
|
|
||||||
|
# important: we need to return the output, rather than
|
||||||
|
# the weak ref of the output, so that pytorch can correctly
|
||||||
|
# manage the memory during cuda graph capture
|
||||||
|
return output
|
||||||
|
|
||||||
|
if self.is_debugging_mode:
|
||||||
|
# check if the input addresses are the same
|
||||||
|
new_input_addresses = [
|
||||||
|
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||||
|
]
|
||||||
|
assert new_input_addresses == entry.input_addresses, (
|
||||||
|
"Input addresses for cudagraphs are different during replay."
|
||||||
|
f" Expected {entry.input_addresses}, got {new_input_addresses}"
|
||||||
|
)
|
||||||
|
|
||||||
|
entry.cudagraph.replay()
|
||||||
|
return entry.output
|
||||||
@@ -0,0 +1,134 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/fix_functionalization.py
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import operator
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||||
|
|
||||||
|
from sglang.srt.model_executor.compilation.fx_utils import is_func
|
||||||
|
from sglang.srt.model_executor.compilation.inductor_pass import SGLangInductorPass
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FixFunctionalizationPass(SGLangInductorPass):
|
||||||
|
"""
|
||||||
|
This pass defunctionalizes certain nodes to avoid redundant tensor copies.
|
||||||
|
After this pass, DCE (dead-code elimination) should never be run,
|
||||||
|
as de-functionalized nodes may appear as dead code.
|
||||||
|
|
||||||
|
To add new nodes to defunctionalize, add to the if-elif chain in __call__.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, graph: torch.fx.Graph):
|
||||||
|
self.begin()
|
||||||
|
self.dump_graph(graph, "before_fix_functionalization")
|
||||||
|
|
||||||
|
self.nodes_to_remove: list[torch.fx.Node] = []
|
||||||
|
count = 0
|
||||||
|
for node in graph.nodes:
|
||||||
|
if not is_func(node, auto_functionalized):
|
||||||
|
continue # Avoid deep if-elif nesting
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
self.dump_graph(graph, "before_fix_functionalization_cleanup")
|
||||||
|
|
||||||
|
# Remove the nodes all at once
|
||||||
|
count_removed = len(self.nodes_to_remove)
|
||||||
|
for node in self.nodes_to_remove:
|
||||||
|
graph.erase_node(node)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"De-functionalized %s nodes, removed %s nodes", count, count_removed
|
||||||
|
)
|
||||||
|
self.dump_graph(graph, "after_fix_functionalization")
|
||||||
|
self.end_and_log()
|
||||||
|
|
||||||
|
def _remove(self, node_or_nodes: Union[torch.fx.Node, Iterable[torch.fx.Node]]):
|
||||||
|
"""
|
||||||
|
Stage a node (or nodes) for removal at the end of the pass.
|
||||||
|
"""
|
||||||
|
if isinstance(node_or_nodes, torch.fx.Node):
|
||||||
|
self.nodes_to_remove.append(node_or_nodes)
|
||||||
|
else:
|
||||||
|
self.nodes_to_remove.extend(node_or_nodes)
|
||||||
|
|
||||||
|
def defunctionalize(
|
||||||
|
self,
|
||||||
|
graph: torch.fx.Graph,
|
||||||
|
node: torch.fx.Node,
|
||||||
|
mutated_args: dict[int, Union[torch.fx.Node, str]],
|
||||||
|
args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
De-functionalize a node by replacing it with a call to the original.
|
||||||
|
It also replaces the getitem users with the mutated arguments.
|
||||||
|
See replace_users_with_mutated_args and insert_defunctionalized.
|
||||||
|
"""
|
||||||
|
self.replace_users_with_mutated_args(node, mutated_args)
|
||||||
|
self.insert_defunctionalized(graph, node, args=args)
|
||||||
|
self._remove(node)
|
||||||
|
|
||||||
|
def replace_users_with_mutated_args(
|
||||||
|
self, node: torch.fx.Node, mutated_args: dict[int, Union[torch.fx.Node, str]]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Replace all getitem users of the auto-functionalized node with the
|
||||||
|
mutated arguments.
|
||||||
|
:param node: The auto-functionalized node
|
||||||
|
:param mutated_args: The mutated arguments, indexed by getitem index.
|
||||||
|
If the value of an arg is a string, `node.kwargs[arg]` is used.
|
||||||
|
"""
|
||||||
|
for idx, user in self.getitem_users(node).items():
|
||||||
|
arg = mutated_args[idx]
|
||||||
|
arg = node.kwargs[arg] if isinstance(arg, str) else arg
|
||||||
|
user.replace_all_uses_with(arg)
|
||||||
|
self._remove(user)
|
||||||
|
|
||||||
|
def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
|
||||||
|
"""
|
||||||
|
Returns the operator.getitem users of the auto-functionalized node,
|
||||||
|
indexed by the index they are getting.
|
||||||
|
"""
|
||||||
|
users = {}
|
||||||
|
for user in node.users:
|
||||||
|
if is_func(user, operator.getitem):
|
||||||
|
idx = user.args[1]
|
||||||
|
users[idx] = user
|
||||||
|
return users
|
||||||
|
|
||||||
|
def insert_defunctionalized(
|
||||||
|
self,
|
||||||
|
graph: torch.fx.Graph,
|
||||||
|
node: torch.fx.Node,
|
||||||
|
args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Insert a new defunctionalized node into the graph before node.
|
||||||
|
If one of the kwargs is 'out', provide args directly,
|
||||||
|
as node.kwargs cannot be used.
|
||||||
|
See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351
|
||||||
|
|
||||||
|
:param graph: Graph to insert the defunctionalized node into
|
||||||
|
:param node: The auto-functionalized node to defunctionalize
|
||||||
|
:param args: If we cannot use kwargs, specify args directly.
|
||||||
|
If an arg is a string, `node.kwargs[arg]` is used.
|
||||||
|
""" # noqa: E501
|
||||||
|
assert is_func(
|
||||||
|
node, auto_functionalized
|
||||||
|
), f"node must be auto-functionalized, is {node} instead"
|
||||||
|
|
||||||
|
# Create a new call to the original function
|
||||||
|
with graph.inserting_before(node):
|
||||||
|
function = node.args[0]
|
||||||
|
if args is None:
|
||||||
|
graph.call_function(function, kwargs=node.kwargs)
|
||||||
|
else:
|
||||||
|
# Args passed as strings refer to items in node.kwargs
|
||||||
|
args = tuple(
|
||||||
|
node.kwargs[arg] if isinstance(arg, str) else arg for arg in args
|
||||||
|
)
|
||||||
|
graph.call_function(function, args=args)
|
||||||
83
python/sglang/srt/model_executor/compilation/fx_utils.py
Normal file
83
python/sglang/srt/model_executor/compilation/fx_utils.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/fx_utils.py
|
||||||
|
|
||||||
|
import operator
|
||||||
|
from collections.abc import Iterable, Iterator
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from torch import fx
|
||||||
|
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||||
|
from torch._ops import OpOverload
|
||||||
|
|
||||||
|
|
||||||
|
def is_func(node: fx.Node, target) -> bool:
|
||||||
|
return node.op == "call_function" and node.target == target
|
||||||
|
|
||||||
|
|
||||||
|
def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
|
||||||
|
return is_func(node, auto_functionalized) and node.args[0] == op
|
||||||
|
|
||||||
|
|
||||||
|
# Returns the first specified node with the given op (if it exists)
|
||||||
|
def find_specified_fn_maybe(
|
||||||
|
nodes: Iterable[fx.Node], op: OpOverload
|
||||||
|
) -> Optional[fx.Node]:
|
||||||
|
for node in nodes:
|
||||||
|
if node.target == op:
|
||||||
|
return node
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Returns the first specified node with the given op
|
||||||
|
def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
|
||||||
|
node = find_specified_fn_maybe(nodes, op)
|
||||||
|
assert node is not None, f"Could not find {op} in nodes {nodes}"
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
# Returns the first auto_functionalized node with the given op (if it exists)
|
||||||
|
def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> Optional[fx.Node]:
|
||||||
|
for node in nodes:
|
||||||
|
if is_func(node, auto_functionalized) and node.args[0] == op: # noqa
|
||||||
|
return node
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Returns the first auto_functionalized node with the given op
|
||||||
|
def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
|
||||||
|
node = find_auto_fn_maybe(nodes, op)
|
||||||
|
assert node is not None, f"Could not find {op} in nodes {nodes}"
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
# Returns the getitem node that extracts the idx-th element from node
|
||||||
|
# (if it exists)
|
||||||
|
def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]:
|
||||||
|
for user in node.users:
|
||||||
|
if is_func(user, operator.getitem) and user.args[1] == idx:
|
||||||
|
return user
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Returns the getitem node that extracts the idx-th element from node
|
||||||
|
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
|
||||||
|
ret = find_getitem_maybe(node, idx)
|
||||||
|
assert ret is not None, f"Could not find getitem {idx} in node {node}"
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
# An auto-functionalization-aware utility for finding nodes with a specific op
|
||||||
|
def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
|
||||||
|
if not op._schema.is_mutable:
|
||||||
|
yield from graph.find_nodes(op="call_function", target=op)
|
||||||
|
|
||||||
|
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
|
||||||
|
if n.args[0] == op:
|
||||||
|
yield n
|
||||||
|
|
||||||
|
|
||||||
|
# Asserts that the node only has one user and returns it
|
||||||
|
# Even if a node has only 1 user, it might share storage with another node,
|
||||||
|
# which might need to be taken into account.
|
||||||
|
def get_only_user(node: fx.Node) -> fx.Node:
|
||||||
|
assert len(node.users) == 1
|
||||||
|
return next(iter(node.users))
|
||||||
140
python/sglang/srt/model_executor/compilation/inductor_pass.py
Normal file
140
python/sglang/srt/model_executor/compilation/inductor_pass.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/inductor_pass.py
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import types
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import fx
|
||||||
|
from torch._dynamo.utils import lazy_format_graph_code
|
||||||
|
from torch._inductor.custom_graph_pass import CustomGraphPass
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_pass_context = None
|
||||||
|
|
||||||
|
|
||||||
|
class PassContext:
|
||||||
|
|
||||||
|
def __init__(self, runtime_shape: Optional[int]):
|
||||||
|
self.runtime_shape = runtime_shape
|
||||||
|
|
||||||
|
|
||||||
|
def get_pass_context() -> PassContext:
|
||||||
|
"""Get the current pass context."""
|
||||||
|
assert _pass_context is not None
|
||||||
|
return _pass_context
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def pass_context(runtime_shape: Optional[int]):
|
||||||
|
"""A context manager that stores the current pass context,
|
||||||
|
usually it is a list of sizes to specialize.
|
||||||
|
"""
|
||||||
|
global _pass_context
|
||||||
|
prev_context = _pass_context
|
||||||
|
_pass_context = PassContext(runtime_shape)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
_pass_context = prev_context
|
||||||
|
|
||||||
|
|
||||||
|
class InductorPass(CustomGraphPass):
|
||||||
|
"""
|
||||||
|
A custom graph pass that uses a hash of its source as the UUID.
|
||||||
|
This is defined as a convenience and should work in most cases.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def uuid(self) -> Any:
|
||||||
|
"""
|
||||||
|
Provide a unique identifier for the pass, used in Inductor code cache.
|
||||||
|
This should depend on the pass implementation, so that changes to the
|
||||||
|
pass result in recompilation.
|
||||||
|
By default, the object source is hashed.
|
||||||
|
"""
|
||||||
|
return InductorPass.hash_source(self)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def hash_source(*srcs: Union[str, Any]):
|
||||||
|
"""
|
||||||
|
Utility method to hash the sources of functions or objects.
|
||||||
|
:param srcs: strings or objects to add to the hash.
|
||||||
|
Objects and functions have their source inspected.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
hasher = hashlib.sha256()
|
||||||
|
for src in srcs:
|
||||||
|
if isinstance(src, str):
|
||||||
|
src_str = src
|
||||||
|
elif isinstance(src, types.FunctionType):
|
||||||
|
src_str = inspect.getsource(src)
|
||||||
|
else:
|
||||||
|
src_str = inspect.getsource(src.__class__)
|
||||||
|
hasher.update(src_str.encode("utf-8"))
|
||||||
|
return hasher.hexdigest()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def hash_dict(dict_: dict[Any, Any]):
|
||||||
|
"""
|
||||||
|
Utility method to hash a dictionary, can alternatively be used for uuid.
|
||||||
|
:return: A sha256 hash of the json rep of the dictionary.
|
||||||
|
"""
|
||||||
|
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
|
||||||
|
return hashlib.sha256(encoded).hexdigest()
|
||||||
|
|
||||||
|
def is_applicable_for_shape(self, shape: Optional[int]):
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class CallableInductorPass(InductorPass):
|
||||||
|
"""
|
||||||
|
This class is a wrapper for a callable that automatically provides an
|
||||||
|
implementation of the UUID.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, callable: Callable[[fx.Graph], None], uuid: Optional[Any] = None
|
||||||
|
):
|
||||||
|
self.callable = callable
|
||||||
|
self._uuid = self.hash_source(callable) if uuid is None else uuid
|
||||||
|
|
||||||
|
def __call__(self, graph: torch.fx.Graph):
|
||||||
|
self.callable(graph)
|
||||||
|
|
||||||
|
def uuid(self) -> Any:
|
||||||
|
return self._uuid
|
||||||
|
|
||||||
|
|
||||||
|
class SGLangInductorPass(InductorPass):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
self.pass_name = self.__class__.__name__
|
||||||
|
|
||||||
|
def dump_graph(self, graph: torch.fx.Graph, stage: str):
|
||||||
|
lazy_format_graph_code(stage, graph.owning_module)
|
||||||
|
|
||||||
|
def begin(self):
|
||||||
|
self._start_time = time.perf_counter_ns()
|
||||||
|
|
||||||
|
def end_and_log(self):
|
||||||
|
self._end_time = time.perf_counter_ns()
|
||||||
|
duration_ms = float(self._end_time - self._start_time) / 1.0e6
|
||||||
|
logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
|
||||||
|
|
||||||
|
|
||||||
|
class PrinterInductorPass(SGLangInductorPass):
|
||||||
|
|
||||||
|
def __init__(self, name: str):
|
||||||
|
super().__init__()
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def __call__(self, graph: torch.fx.Graph):
|
||||||
|
self.dump_graph(graph, self.name)
|
||||||
68
python/sglang/srt/model_executor/compilation/pass_manager.py
Normal file
68
python/sglang/srt/model_executor/compilation/pass_manager.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/pass_manager.py
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from torch import fx as fx
|
||||||
|
|
||||||
|
from sglang.srt.model_executor.compilation.fix_functionalization import (
|
||||||
|
FixFunctionalizationPass,
|
||||||
|
)
|
||||||
|
from sglang.srt.model_executor.compilation.inductor_pass import (
|
||||||
|
CustomGraphPass,
|
||||||
|
InductorPass,
|
||||||
|
SGLangInductorPass,
|
||||||
|
get_pass_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PostGradPassManager(CustomGraphPass):
|
||||||
|
"""
|
||||||
|
The pass manager for post-grad passes.
|
||||||
|
It handles configuration, adding custom passes, and running passes.
|
||||||
|
It supports uuid for the Inductor code cache. That includes torch<2.6
|
||||||
|
support using pickling (in .inductor_pass.CustomGraphPass).
|
||||||
|
|
||||||
|
The order of the post-grad post-passes is:
|
||||||
|
1. passes (constructor parameter)
|
||||||
|
2. default passes (NoopEliminationPass, FusionPass)
|
||||||
|
3. config["post_grad_custom_post_pass"] (if it exists)
|
||||||
|
4. fix_functionalization
|
||||||
|
This way, all passes operate on a functionalized graph.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.passes: list[SGLangInductorPass] = []
|
||||||
|
|
||||||
|
def __call__(self, graph: fx.Graph):
|
||||||
|
shape = get_pass_context().runtime_shape
|
||||||
|
for pass_ in self.passes:
|
||||||
|
if pass_.is_applicable_for_shape(shape):
|
||||||
|
pass_(graph)
|
||||||
|
|
||||||
|
# always run fix_functionalization last
|
||||||
|
self.fix_functionalization(graph)
|
||||||
|
|
||||||
|
def configure(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
self.pass_config = dict()
|
||||||
|
self.fix_functionalization = FixFunctionalizationPass()
|
||||||
|
|
||||||
|
def add(self, pass_: InductorPass):
|
||||||
|
assert isinstance(pass_, InductorPass)
|
||||||
|
self.passes.append(pass_)
|
||||||
|
|
||||||
|
def uuid(self):
|
||||||
|
"""
|
||||||
|
The PostGradPassManager is set as a custom pass in the Inductor and
|
||||||
|
affects compilation caching. Its uuid depends on the UUIDs of all
|
||||||
|
dependent passes and the pass config. See InductorPass for more info.
|
||||||
|
"""
|
||||||
|
pass_manager_uuid = "fshdakhsa"
|
||||||
|
state = {"pass_config": pass_manager_uuid, "passes": []}
|
||||||
|
for pass_ in self.passes:
|
||||||
|
state["passes"].append(pass_.uuid())
|
||||||
|
state["passes"].append(self.fix_functionalization.uuid())
|
||||||
|
return InductorPass.hash_dict(state)
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
from contextlib import contextmanager
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ForwardContext:
|
||||||
|
def __init__(self):
|
||||||
|
self.forward_batch = None
|
||||||
|
self.attention_layer = None
|
||||||
|
|
||||||
|
def set_forward_batch(self, forward_batch: ForwardBatch):
|
||||||
|
self.forward_batch = forward_batch
|
||||||
|
|
||||||
|
def set_attention_layers(self, layers: List[Any]):
|
||||||
|
self.attention_layers = layers
|
||||||
|
|
||||||
|
|
||||||
|
_forward_context: Optional[ForwardContext] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_forward_context() -> Optional[ForwardContext]:
|
||||||
|
if _forward_context is None:
|
||||||
|
return None
|
||||||
|
return _forward_context
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def set_forward_context(forward_batch: ForwardBatch, attention_layers: List[Any]):
|
||||||
|
global _forward_context
|
||||||
|
prev_forward_context = _forward_context
|
||||||
|
_forward_context = ForwardContext()
|
||||||
|
_forward_context.set_forward_batch(forward_batch)
|
||||||
|
_forward_context.set_attention_layers(attention_layers)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
_forward_context = prev_forward_context
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
// Adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/ops.h
|
||||||
|
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
static at::Tensor weak_ref_tensor(at::Tensor &tensor) {
|
||||||
|
TORCH_CHECK(tensor.is_cuda(), "weak_ref_tensor expects a CUDA tensor");
|
||||||
|
|
||||||
|
void *data_ptr = tensor.data_ptr();
|
||||||
|
std::vector<int64_t> sizes = tensor.sizes().vec();
|
||||||
|
std::vector<int64_t> strides = tensor.strides().vec();
|
||||||
|
|
||||||
|
auto options = tensor.options();
|
||||||
|
|
||||||
|
auto new_tensor = torch::from_blob(data_ptr, sizes, strides, options);
|
||||||
|
|
||||||
|
return new_tensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY(jit_weak_ref_tensor, ops) {
|
||||||
|
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
|
||||||
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY_IMPL(jit_weak_ref_tensor, CUDA, ops) {
|
||||||
|
ops.impl("weak_ref_tensor", weak_ref_tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.cpp_extension import load
|
||||||
|
|
||||||
|
_abs_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
load(
|
||||||
|
name="weak_ref_tensor_ext",
|
||||||
|
sources=[f"{_abs_path}/weak_ref_tensor.cpp"],
|
||||||
|
extra_cflags=["-O3"],
|
||||||
|
)
|
||||||
|
|
||||||
|
x = torch.arange(12, device="cuda").reshape(3, 4)
|
||||||
|
y = torch.ops.jit_weak_ref_tensor.weak_ref_tensor(x)
|
||||||
|
print("alias:", x.data_ptr() == y.data_ptr())
|
||||||
@@ -108,8 +108,15 @@ from sglang.srt.mem_cache.memory_pool import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
|
from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
|
ForwardBatch,
|
||||||
|
ForwardMode,
|
||||||
|
PPProxyTensors,
|
||||||
|
)
|
||||||
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
|
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
|
||||||
|
from sglang.srt.model_executor.piecewise_cuda_graph_runner import (
|
||||||
|
PiecewiseCudaGraphRunner,
|
||||||
|
)
|
||||||
from sglang.srt.model_loader import get_model
|
from sglang.srt.model_loader import get_model
|
||||||
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
|
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
|
||||||
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
|
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
|
||||||
@@ -307,6 +314,26 @@ class ModelRunner:
|
|||||||
self._model_update_group = {}
|
self._model_update_group = {}
|
||||||
self._weights_send_group = {}
|
self._weights_send_group = {}
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.server_args.enable_piecewise_cuda_graph
|
||||||
|
and self.can_run_piecewise_cuda_graph()
|
||||||
|
):
|
||||||
|
self.attention_layers = []
|
||||||
|
for layer in self.model.model.layers:
|
||||||
|
if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "attn"):
|
||||||
|
self.attention_layers.append(layer.self_attn.attn)
|
||||||
|
if len(self.attention_layers) < self.model_config.num_hidden_layers:
|
||||||
|
# TODO(yuwei): support Non-Standard GQA
|
||||||
|
log_info_on_rank0(
|
||||||
|
logger,
|
||||||
|
"Disable piecewise CUDA graph because some layers do not apply Standard GQA",
|
||||||
|
)
|
||||||
|
self.piecewise_cuda_graph_runner = None
|
||||||
|
else:
|
||||||
|
self.piecewise_cuda_graph_runner = PiecewiseCudaGraphRunner(self)
|
||||||
|
else:
|
||||||
|
self.piecewise_cuda_graph_runner = None
|
||||||
|
|
||||||
def initialize(self, min_per_gpu_memory: float):
|
def initialize(self, min_per_gpu_memory: float):
|
||||||
server_args = self.server_args
|
server_args = self.server_args
|
||||||
|
|
||||||
@@ -692,6 +719,7 @@ class ModelRunner:
|
|||||||
pipeline_model_parallel_size=self.pp_size,
|
pipeline_model_parallel_size=self.pp_size,
|
||||||
expert_model_parallel_size=self.moe_ep_size,
|
expert_model_parallel_size=self.moe_ep_size,
|
||||||
duplicate_tp_group=self.server_args.enable_pdmux,
|
duplicate_tp_group=self.server_args.enable_pdmux,
|
||||||
|
torch_compile=self.server_args.enable_piecewise_cuda_graph,
|
||||||
)
|
)
|
||||||
initialize_dp_attention(
|
initialize_dp_attention(
|
||||||
server_args=self.server_args,
|
server_args=self.server_args,
|
||||||
@@ -1411,6 +1439,27 @@ class ModelRunner:
|
|||||||
f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
|
f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def can_run_piecewise_cuda_graph(self):
|
||||||
|
if self.server_args.disable_cuda_graph:
|
||||||
|
log_info_on_rank0(
|
||||||
|
logger, "Disable piecewise CUDA graph because disable_cuda_graph is set"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
if self.server_args.enable_torch_compile:
|
||||||
|
log_info_on_rank0(
|
||||||
|
logger,
|
||||||
|
"Disable piecewise CUDA graph because piecewise_cuda_graph has conflict with torch compile",
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
if self.pp_size > 1:
|
||||||
|
# TODO(yuwei): support PP
|
||||||
|
log_info_on_rank0(
|
||||||
|
logger,
|
||||||
|
"Disable piecewise CUDA graph because piecewise_cuda_graph does not support PP",
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def init_memory_pool(
|
def init_memory_pool(
|
||||||
self,
|
self,
|
||||||
total_gpu_memory: int,
|
total_gpu_memory: int,
|
||||||
@@ -1932,6 +1981,11 @@ class ModelRunner:
|
|||||||
kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
|
kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
|
||||||
if not self.is_generation:
|
if not self.is_generation:
|
||||||
kwargs["get_embedding"] = True
|
kwargs["get_embedding"] = True
|
||||||
|
|
||||||
|
if self.piecewise_cuda_graph_runner is not None:
|
||||||
|
if self.piecewise_cuda_graph_runner.can_run(forward_batch):
|
||||||
|
return self.piecewise_cuda_graph_runner.replay(forward_batch, **kwargs)
|
||||||
|
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
forward_batch.input_ids,
|
forward_batch.input_ids,
|
||||||
forward_batch.positions,
|
forward_batch.positions,
|
||||||
|
|||||||
532
python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
Normal file
532
python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
Normal file
@@ -0,0 +1,532 @@
|
|||||||
|
# Copyright 2023-2024 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Run the model with cuda graph and torch.compile."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import bisect
|
||||||
|
import gc
|
||||||
|
import logging
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import TYPE_CHECKING, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
from sglang.srt.custom_op import CustomOp
|
||||||
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||||
|
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
|
||||||
|
set_graph_pool_id,
|
||||||
|
)
|
||||||
|
from sglang.srt.distributed.parallel_state import graph_capture
|
||||||
|
from sglang.srt.layers.dp_attention import (
|
||||||
|
DpPaddingMode,
|
||||||
|
get_attention_tp_rank,
|
||||||
|
get_attention_tp_size,
|
||||||
|
set_dp_buffer_len,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
|
from sglang.srt.layers.torchao_utils import save_gemlite_cache
|
||||||
|
from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig
|
||||||
|
from sglang.srt.model_executor.compilation.compile import (
|
||||||
|
install_torch_compiled,
|
||||||
|
set_compiled,
|
||||||
|
)
|
||||||
|
from sglang.srt.model_executor.compilation.piecewise_context_manager import (
|
||||||
|
set_forward_context,
|
||||||
|
)
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
|
CaptureHiddenMode,
|
||||||
|
ForwardBatch,
|
||||||
|
ForwardMode,
|
||||||
|
PPProxyTensors,
|
||||||
|
)
|
||||||
|
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
|
||||||
|
from sglang.srt.utils import get_available_gpu_memory, log_info_on_rank0
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
|
# Detect whether the current forward pass is in capture mode
|
||||||
|
is_capture_mode = False
|
||||||
|
|
||||||
|
|
||||||
|
def get_is_capture_mode():
|
||||||
|
return is_capture_mode
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def model_capture_mode():
|
||||||
|
global is_capture_mode
|
||||||
|
is_capture_mode = True
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
is_capture_mode = False
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def freeze_gc(enable_cudagraph_gc: bool):
|
||||||
|
"""
|
||||||
|
Optimize garbage collection during CUDA graph capture.
|
||||||
|
Clean up, then freeze all remaining objects from being included
|
||||||
|
in future collections if GC is disabled during capture.
|
||||||
|
"""
|
||||||
|
gc.collect()
|
||||||
|
should_freeze = not enable_cudagraph_gc
|
||||||
|
if should_freeze:
|
||||||
|
gc.freeze()
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
if should_freeze:
|
||||||
|
gc.unfreeze()
|
||||||
|
|
||||||
|
|
||||||
|
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
|
||||||
|
for sub in model._modules.values():
|
||||||
|
if isinstance(sub, CustomOp):
|
||||||
|
if reverse:
|
||||||
|
sub.leave_torch_compile()
|
||||||
|
else:
|
||||||
|
sub.enter_torch_compile(num_tokens=num_tokens)
|
||||||
|
if isinstance(sub, torch.nn.Module):
|
||||||
|
_to_torch(sub, reverse, num_tokens)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def patch_model(model: torch.nn.Module):
|
||||||
|
try:
|
||||||
|
_to_torch(model, reverse=False, num_tokens=16)
|
||||||
|
yield model
|
||||||
|
finally:
|
||||||
|
_to_torch(model, reverse=True, num_tokens=16)
|
||||||
|
|
||||||
|
|
||||||
|
# Reuse this memory pool across all cuda graph runners.
|
||||||
|
global_graph_memory_pool = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_global_graph_memory_pool():
|
||||||
|
return global_graph_memory_pool
|
||||||
|
|
||||||
|
|
||||||
|
def set_global_graph_memory_pool(val):
|
||||||
|
global global_graph_memory_pool
|
||||||
|
global_graph_memory_pool = val
|
||||||
|
|
||||||
|
|
||||||
|
class PiecewiseCudaGraphRunner:
|
||||||
|
"""A PiecewiseCudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
|
||||||
|
|
||||||
|
def __init__(self, model_runner: ModelRunner):
|
||||||
|
# Parse args
|
||||||
|
self.model_runner = model_runner
|
||||||
|
self.device = model_runner.device
|
||||||
|
self.device_module = torch.get_device_module(self.device)
|
||||||
|
self.graphs = {}
|
||||||
|
self.output_buffers = {}
|
||||||
|
self.tp_size = model_runner.server_args.tp_size
|
||||||
|
self.dp_size = model_runner.server_args.dp_size
|
||||||
|
self.pp_size = model_runner.server_args.pp_size
|
||||||
|
|
||||||
|
self.attn_tp_size = get_attention_tp_size()
|
||||||
|
self.attn_tp_rank = get_attention_tp_rank()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
self.model_runner.server_args.piecewise_cuda_graph_tokens is not None
|
||||||
|
), "piecewise_cuda_graph_tokens is not set"
|
||||||
|
self.compile_config = CompilationConfig(
|
||||||
|
self.model_runner.server_args.piecewise_cuda_graph_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
# Batch sizes to capture
|
||||||
|
self.capture_num_tokens = self.compile_config.get_capture_sizes()
|
||||||
|
log_info_on_rank0(
|
||||||
|
logger, f"Capture cuda graph num tokens {self.capture_num_tokens}"
|
||||||
|
)
|
||||||
|
self.capture_forward_mode = ForwardMode.EXTEND
|
||||||
|
self.capture_hidden_mode = CaptureHiddenMode.NULL
|
||||||
|
|
||||||
|
# If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
|
||||||
|
if model_runner.server_args.enable_return_hidden_states:
|
||||||
|
self.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||||
|
|
||||||
|
# Attention backend
|
||||||
|
self.max_num_tokens = max(self.capture_num_tokens)
|
||||||
|
|
||||||
|
# Graph inputs
|
||||||
|
with torch.device(self.device):
|
||||||
|
self.input_ids = torch.zeros((self.max_num_tokens,), dtype=torch.int64)
|
||||||
|
self.out_cache_loc = torch.zeros(
|
||||||
|
(self.max_num_tokens,), dtype=self._cache_loc_dtype()
|
||||||
|
)
|
||||||
|
self.positions = torch.zeros((self.max_num_tokens,), dtype=torch.int64)
|
||||||
|
self.tbo_plugin = TboCudaGraphRunnerPlugin()
|
||||||
|
|
||||||
|
self.attention_layers = self.model_runner.attention_layers
|
||||||
|
|
||||||
|
if get_global_graph_memory_pool() is None:
|
||||||
|
set_global_graph_memory_pool(self.device_module.graph_pool_handle())
|
||||||
|
# Set graph pool id globally to be able to use symmetric memory
|
||||||
|
set_graph_pool_id(get_global_graph_memory_pool())
|
||||||
|
|
||||||
|
with patch_model(self.model_runner.model.model) as patched_model:
|
||||||
|
install_torch_compiled(
|
||||||
|
patched_model,
|
||||||
|
fullgraph=True,
|
||||||
|
dynamic_arg_dims=None,
|
||||||
|
compile_config=self.compile_config,
|
||||||
|
graph_pool=get_global_graph_memory_pool(),
|
||||||
|
)
|
||||||
|
|
||||||
|
with set_compiled(True):
|
||||||
|
self.warmup_and_capture()
|
||||||
|
|
||||||
|
# Capture
|
||||||
|
try:
|
||||||
|
with model_capture_mode():
|
||||||
|
self.capture()
|
||||||
|
except RuntimeError as e:
|
||||||
|
raise Exception(
|
||||||
|
f"Capture cuda graph failed: {e}\n{PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.raw_num_tokens = 0
|
||||||
|
|
||||||
|
def warmup_and_capture(self):
|
||||||
|
num_tokens = 2
|
||||||
|
with torch.device(self.device):
|
||||||
|
forward_batch = ForwardBatch(
|
||||||
|
forward_mode=ForwardMode.EXTEND,
|
||||||
|
batch_size=1,
|
||||||
|
input_ids=torch.randint(0, 100, (num_tokens,), device=self.device),
|
||||||
|
req_pool_indices=torch.arange(1, device=self.device),
|
||||||
|
seq_lens=torch.tensor([num_tokens], device=self.device),
|
||||||
|
next_token_logits_buffer=None,
|
||||||
|
orig_seq_lens=torch.tensor([num_tokens], device=self.device),
|
||||||
|
seq_lens_cpu=torch.tensor([num_tokens]),
|
||||||
|
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||||
|
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||||
|
attn_backend=self.model_runner.attn_backend,
|
||||||
|
out_cache_loc=torch.randint(0, 100, (num_tokens,), device=self.device),
|
||||||
|
seq_lens_sum=num_tokens,
|
||||||
|
encoder_lens=None,
|
||||||
|
return_logprob=False,
|
||||||
|
extend_seq_lens=torch.tensor([num_tokens], device=self.device),
|
||||||
|
extend_prefix_lens=torch.tensor([num_tokens], device=self.device),
|
||||||
|
extend_start_loc=torch.tensor([0], device=self.device),
|
||||||
|
extend_prefix_lens_cpu=torch.tensor([num_tokens]),
|
||||||
|
extend_seq_lens_cpu=torch.tensor([num_tokens]),
|
||||||
|
extend_logprob_start_lens_cpu=torch.tensor([num_tokens]),
|
||||||
|
positions=torch.arange(num_tokens, device=self.device),
|
||||||
|
global_num_tokens_gpu=None,
|
||||||
|
global_num_tokens_for_logprob_gpu=None,
|
||||||
|
dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
|
||||||
|
global_dp_buffer_len=None,
|
||||||
|
mrope_positions=None,
|
||||||
|
spec_algorithm=None,
|
||||||
|
spec_info=None,
|
||||||
|
capture_hidden_mode=CaptureHiddenMode.NULL,
|
||||||
|
num_token_non_padded=None,
|
||||||
|
global_forward_mode=ForwardMode.EXTEND,
|
||||||
|
lora_ids=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
with set_forward_context(forward_batch, self.attention_layers):
|
||||||
|
_ = self.model_runner.model.forward(
|
||||||
|
forward_batch.input_ids,
|
||||||
|
forward_batch.positions,
|
||||||
|
forward_batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _cache_loc_dtype(self):
|
||||||
|
return torch.int64
|
||||||
|
|
||||||
|
def can_run(self, forward_batch: ForwardBatch):
|
||||||
|
num_tokens = len(forward_batch.input_ids)
|
||||||
|
# TODO(yuwei): support return logprob
|
||||||
|
if forward_batch.return_logprob:
|
||||||
|
return False
|
||||||
|
if num_tokens <= self.max_num_tokens:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def capture(self) -> None:
|
||||||
|
# Trigger CUDA graph capture for specific shapes.
|
||||||
|
# Capture the large shapes first so that the smaller shapes
|
||||||
|
# can reuse the memory pool allocated for the large shapes.
|
||||||
|
with freeze_gc(
|
||||||
|
self.model_runner.server_args.enable_cudagraph_gc
|
||||||
|
), graph_capture() as graph_capture_context:
|
||||||
|
self.stream = graph_capture_context.stream
|
||||||
|
avail_mem = get_available_gpu_memory(
|
||||||
|
self.model_runner.device,
|
||||||
|
self.model_runner.gpu_id,
|
||||||
|
empty_cache=False,
|
||||||
|
)
|
||||||
|
# Reverse the order to enable better memory sharing across cuda graphs.
|
||||||
|
capture_range = (
|
||||||
|
tqdm.tqdm(list(reversed(self.capture_num_tokens)))
|
||||||
|
if get_tensor_model_parallel_rank() == 0
|
||||||
|
else reversed(self.capture_num_tokens)
|
||||||
|
)
|
||||||
|
for i, num_tokens in enumerate(capture_range):
|
||||||
|
if get_tensor_model_parallel_rank() == 0:
|
||||||
|
avail_mem = get_available_gpu_memory(
|
||||||
|
self.model_runner.device,
|
||||||
|
self.model_runner.gpu_id,
|
||||||
|
empty_cache=False,
|
||||||
|
)
|
||||||
|
capture_range.set_description(
|
||||||
|
f"Capturing num tokens ({num_tokens=} {avail_mem=:.2f} GB)"
|
||||||
|
)
|
||||||
|
|
||||||
|
with set_compiled(True):
|
||||||
|
self.capture_one_batch_size(num_tokens)
|
||||||
|
|
||||||
|
# Save gemlite cache after each capture
|
||||||
|
save_gemlite_cache()
|
||||||
|
|
||||||
|
def capture_one_batch_size(self, num_tokens: int):
|
||||||
|
stream = self.stream
|
||||||
|
bs = 1
|
||||||
|
|
||||||
|
# Graph inputs
|
||||||
|
input_ids = self.input_ids[:num_tokens]
|
||||||
|
out_cache_loc = self.out_cache_loc[:num_tokens]
|
||||||
|
positions = self.positions[:num_tokens]
|
||||||
|
|
||||||
|
# pipeline parallelism
|
||||||
|
if self.pp_size > 1:
|
||||||
|
pp_proxy_tensors = PPProxyTensors(
|
||||||
|
{k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
|
||||||
|
)
|
||||||
|
|
||||||
|
global_dp_buffer_len = None
|
||||||
|
|
||||||
|
if self.model_runner.server_args.enable_lora:
|
||||||
|
# It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
|
||||||
|
# `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
|
||||||
|
lora_ids = [None] * bs
|
||||||
|
else:
|
||||||
|
lora_ids = None
|
||||||
|
|
||||||
|
with torch.device(self.device):
|
||||||
|
forward_batch = ForwardBatch(
|
||||||
|
forward_mode=ForwardMode.EXTEND,
|
||||||
|
batch_size=bs,
|
||||||
|
input_ids=input_ids,
|
||||||
|
req_pool_indices=torch.arange(bs, device=self.device),
|
||||||
|
seq_lens=torch.tensor([num_tokens], device=self.device),
|
||||||
|
next_token_logits_buffer=None,
|
||||||
|
orig_seq_lens=torch.tensor([num_tokens], device=self.device),
|
||||||
|
seq_lens_cpu=torch.tensor([num_tokens]),
|
||||||
|
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||||
|
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||||
|
attn_backend=self.model_runner.attn_backend,
|
||||||
|
out_cache_loc=out_cache_loc,
|
||||||
|
seq_lens_sum=num_tokens,
|
||||||
|
encoder_lens=None,
|
||||||
|
return_logprob=False,
|
||||||
|
extend_seq_lens=torch.tensor([num_tokens], device=self.device),
|
||||||
|
extend_prefix_lens=torch.tensor([num_tokens], device=self.device),
|
||||||
|
extend_start_loc=torch.tensor([0], device=self.device),
|
||||||
|
extend_prefix_lens_cpu=torch.tensor([num_tokens]),
|
||||||
|
extend_seq_lens_cpu=torch.tensor([num_tokens]),
|
||||||
|
extend_logprob_start_lens_cpu=torch.tensor([num_tokens]),
|
||||||
|
positions=positions,
|
||||||
|
global_num_tokens_gpu=None,
|
||||||
|
global_num_tokens_for_logprob_gpu=None,
|
||||||
|
dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
|
||||||
|
global_dp_buffer_len=None,
|
||||||
|
mrope_positions=None,
|
||||||
|
spec_algorithm=None,
|
||||||
|
spec_info=None,
|
||||||
|
capture_hidden_mode=CaptureHiddenMode.NULL,
|
||||||
|
num_token_non_padded=None,
|
||||||
|
global_forward_mode=ForwardMode.EXTEND,
|
||||||
|
lora_ids=None,
|
||||||
|
)
|
||||||
|
self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
|
||||||
|
|
||||||
|
if lora_ids is not None:
|
||||||
|
self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
|
||||||
|
|
||||||
|
# # Attention backend
|
||||||
|
self.model_runner.attn_backend.init_forward_metadata(forward_batch)
|
||||||
|
|
||||||
|
# Run and capture
|
||||||
|
def run_once():
|
||||||
|
# Clean intermediate result cache for DP attention
|
||||||
|
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
||||||
|
set_dp_buffer_len(global_dp_buffer_len, num_tokens)
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
with set_forward_context(forward_batch, self.attention_layers):
|
||||||
|
self.model_runner.model.forward(
|
||||||
|
forward_batch.input_ids,
|
||||||
|
forward_batch.positions,
|
||||||
|
forward_batch,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
for _ in range(2):
|
||||||
|
self.device_module.synchronize()
|
||||||
|
self.model_runner.tp_group.barrier()
|
||||||
|
run_once()
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
def replay_prepare(
|
||||||
|
self,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
num_tokens = len(forward_batch.input_ids)
|
||||||
|
index = bisect.bisect_left(self.capture_num_tokens, num_tokens)
|
||||||
|
static_num_tokens = self.capture_num_tokens[index]
|
||||||
|
self.raw_num_tokens = num_tokens
|
||||||
|
if static_num_tokens != num_tokens:
|
||||||
|
self.out_cache_loc.zero_()
|
||||||
|
bs = forward_batch.batch_size
|
||||||
|
|
||||||
|
self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
|
||||||
|
self.positions[:num_tokens].copy_(forward_batch.positions)
|
||||||
|
self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
|
||||||
|
|
||||||
|
input_ids = self.input_ids[:static_num_tokens]
|
||||||
|
positions = self.positions[:static_num_tokens]
|
||||||
|
out_cache_loc = self.out_cache_loc[:static_num_tokens]
|
||||||
|
|
||||||
|
next_token_logits_buffer = None
|
||||||
|
mrope_positions = None
|
||||||
|
|
||||||
|
static_forward_batch = ForwardBatch(
|
||||||
|
forward_mode=forward_batch.forward_mode,
|
||||||
|
batch_size=bs,
|
||||||
|
input_ids=input_ids,
|
||||||
|
req_pool_indices=forward_batch.req_pool_indices,
|
||||||
|
seq_lens=forward_batch.seq_lens,
|
||||||
|
next_token_logits_buffer=next_token_logits_buffer,
|
||||||
|
orig_seq_lens=forward_batch.orig_seq_lens,
|
||||||
|
seq_lens_cpu=forward_batch.seq_lens_cpu,
|
||||||
|
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||||
|
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||||
|
attn_backend=self.model_runner.attn_backend,
|
||||||
|
out_cache_loc=out_cache_loc,
|
||||||
|
seq_lens_sum=forward_batch.seq_lens_sum,
|
||||||
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
|
return_logprob=forward_batch.return_logprob,
|
||||||
|
extend_seq_lens=forward_batch.extend_seq_lens,
|
||||||
|
extend_prefix_lens=forward_batch.extend_prefix_lens,
|
||||||
|
extend_start_loc=forward_batch.extend_start_loc,
|
||||||
|
extend_prefix_lens_cpu=forward_batch.extend_prefix_lens_cpu,
|
||||||
|
extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
|
||||||
|
extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
|
||||||
|
extend_num_tokens=forward_batch.extend_num_tokens,
|
||||||
|
extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
|
||||||
|
positions=positions,
|
||||||
|
global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
|
||||||
|
global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
|
||||||
|
dp_padding_mode=forward_batch.dp_padding_mode,
|
||||||
|
global_dp_buffer_len=forward_batch.global_dp_buffer_len,
|
||||||
|
mrope_positions=mrope_positions,
|
||||||
|
spec_algorithm=forward_batch.spec_algorithm,
|
||||||
|
spec_info=forward_batch.spec_info,
|
||||||
|
capture_hidden_mode=forward_batch.capture_hidden_mode,
|
||||||
|
num_token_non_padded=forward_batch.num_token_non_padded,
|
||||||
|
global_forward_mode=forward_batch.global_forward_mode,
|
||||||
|
lora_ids=forward_batch.lora_ids,
|
||||||
|
sampling_info=forward_batch.sampling_info,
|
||||||
|
mm_inputs=forward_batch.mm_inputs,
|
||||||
|
temp_scaled_logprobs=forward_batch.temp_scaled_logprobs,
|
||||||
|
temperature=forward_batch.temperature,
|
||||||
|
top_p_normalized_logprobs=forward_batch.top_p_normalized_logprobs,
|
||||||
|
top_p=forward_batch.top_p,
|
||||||
|
)
|
||||||
|
|
||||||
|
return static_forward_batch
|
||||||
|
|
||||||
|
def replay(
|
||||||
|
self,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[LogitsProcessorOutput, PPProxyTensors]:
|
||||||
|
static_forward_batch = self.replay_prepare(forward_batch, **kwargs)
|
||||||
|
# Replay
|
||||||
|
with set_forward_context(static_forward_batch, self.attention_layers):
|
||||||
|
with set_compiled(True):
|
||||||
|
output = self.model_runner.model.forward(
|
||||||
|
static_forward_batch.input_ids,
|
||||||
|
static_forward_batch.positions,
|
||||||
|
static_forward_batch,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
if isinstance(output, LogitsProcessorOutput):
|
||||||
|
return LogitsProcessorOutput(
|
||||||
|
next_token_logits=output.next_token_logits[: self.raw_num_tokens],
|
||||||
|
hidden_states=(
|
||||||
|
output.hidden_states[: self.raw_num_tokens]
|
||||||
|
if output.hidden_states is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert isinstance(output, PPProxyTensors)
|
||||||
|
# TODO(Yuwei): support PP Support
|
||||||
|
raise NotImplementedError(
|
||||||
|
"PPProxyTensors is not supported in PiecewiseCudaGraphRunner yet."
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_spec_info(self, num_tokens: int):
|
||||||
|
spec_info = None
|
||||||
|
if (
|
||||||
|
self.model_runner.spec_algorithm.is_eagle()
|
||||||
|
or self.model_runner.spec_algorithm.is_standalone()
|
||||||
|
):
|
||||||
|
from sglang.srt.speculative.eagle_utils import EagleVerifyInput
|
||||||
|
|
||||||
|
if self.model_runner.is_draft_worker:
|
||||||
|
raise RuntimeError("This should not happen.")
|
||||||
|
else:
|
||||||
|
spec_info = EagleVerifyInput(
|
||||||
|
draft_token=None,
|
||||||
|
custom_mask=self.custom_mask,
|
||||||
|
positions=None,
|
||||||
|
retrive_index=None,
|
||||||
|
retrive_next_token=None,
|
||||||
|
retrive_next_sibling=None,
|
||||||
|
retrive_cum_len=None,
|
||||||
|
spec_steps=self.model_runner.server_args.speculative_num_steps,
|
||||||
|
topk=self.model_runner.server_args.speculative_eagle_topk,
|
||||||
|
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
|
||||||
|
capture_hidden_mode=CaptureHiddenMode.FULL,
|
||||||
|
seq_lens_sum=None,
|
||||||
|
seq_lens_cpu=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return spec_info
|
||||||
|
|
||||||
|
|
||||||
|
PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG = (
|
||||||
|
"Possible solutions:\n"
|
||||||
|
"1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
|
||||||
|
"2. set --piecewise-cuda-graph-max-tokens to a smaller value (e.g., 512)\n"
|
||||||
|
"3. disable Piecewise CUDA graph by unset --enable-piecewise-cuda-graph\n"
|
||||||
|
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
|
||||||
|
)
|
||||||
@@ -417,7 +417,10 @@ class ServerArgs:
|
|||||||
enable_single_batch_overlap: bool = False
|
enable_single_batch_overlap: bool = False
|
||||||
tbo_token_distribution_threshold: float = 0.48
|
tbo_token_distribution_threshold: float = 0.48
|
||||||
enable_torch_compile: bool = False
|
enable_torch_compile: bool = False
|
||||||
|
enable_piecewise_cuda_graph: bool = False
|
||||||
torch_compile_max_bs: int = 32
|
torch_compile_max_bs: int = 32
|
||||||
|
piecewise_cuda_graph_max_tokens: int = 4096
|
||||||
|
piecewise_cuda_graph_tokens: Optional[List[int]] = None
|
||||||
torchao_config: str = ""
|
torchao_config: str = ""
|
||||||
enable_nan_detection: bool = False
|
enable_nan_detection: bool = False
|
||||||
enable_p2p_check: bool = False
|
enable_p2p_check: bool = False
|
||||||
@@ -675,6 +678,11 @@ class ServerArgs:
|
|||||||
else:
|
else:
|
||||||
self.cuda_graph_max_bs = max(self.cuda_graph_bs)
|
self.cuda_graph_max_bs = max(self.cuda_graph_bs)
|
||||||
|
|
||||||
|
if self.piecewise_cuda_graph_tokens is None:
|
||||||
|
self.piecewise_cuda_graph_tokens = (
|
||||||
|
self._generate_piecewise_cuda_graph_tokens()
|
||||||
|
)
|
||||||
|
|
||||||
if self.mem_fraction_static is None:
|
if self.mem_fraction_static is None:
|
||||||
# Constant meta data (e.g., from attention backend)
|
# Constant meta data (e.g., from attention backend)
|
||||||
reserved_mem = 512
|
reserved_mem = 512
|
||||||
@@ -753,6 +761,25 @@ class ServerArgs:
|
|||||||
|
|
||||||
return capture_bs
|
return capture_bs
|
||||||
|
|
||||||
|
def _generate_piecewise_cuda_graph_tokens(self):
|
||||||
|
"""
|
||||||
|
Generate the list of batch sizes for piecewise CUDA graph capture
|
||||||
|
based on piecewise_cuda_graph_max_tokens.
|
||||||
|
"""
|
||||||
|
capture_sizes = (
|
||||||
|
list(range(4, 33, 4))
|
||||||
|
+ list(range(48, 257, 16))
|
||||||
|
+ list(range(288, 513, 32))
|
||||||
|
+ list(range(640, 4096 + 1, 128))
|
||||||
|
+ list(range(4352, self.piecewise_cuda_graph_max_tokens + 1, 256))
|
||||||
|
)
|
||||||
|
|
||||||
|
capture_sizes = [
|
||||||
|
s for s in capture_sizes if s <= self.piecewise_cuda_graph_max_tokens
|
||||||
|
]
|
||||||
|
|
||||||
|
return capture_sizes
|
||||||
|
|
||||||
def _handle_hpu_backends(self):
|
def _handle_hpu_backends(self):
|
||||||
if self.device == "hpu":
|
if self.device == "hpu":
|
||||||
self.attention_backend = "torch_native"
|
self.attention_backend = "torch_native"
|
||||||
@@ -2649,12 +2676,29 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Optimize the model with torch.compile. Experimental feature.",
|
help="Optimize the model with torch.compile. Experimental feature.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-piecewise-cuda-graph",
|
||||||
|
action="store_true",
|
||||||
|
help="Optimize the model with piecewise cuda graph for extend/prefill only. Experimental feature.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--piecewise-cuda-graph-tokens",
|
||||||
|
type=json_list_type,
|
||||||
|
default=ServerArgs.piecewise_cuda_graph_tokens,
|
||||||
|
help="Set the list of tokens when using piecewise cuda graph.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--torch-compile-max-bs",
|
"--torch-compile-max-bs",
|
||||||
type=int,
|
type=int,
|
||||||
default=ServerArgs.torch_compile_max_bs,
|
default=ServerArgs.torch_compile_max_bs,
|
||||||
help="Set the maximum batch size when using torch compile.",
|
help="Set the maximum batch size when using torch compile.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--piecewise-cuda-graph-max-tokens",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.piecewise_cuda_graph_max_tokens,
|
||||||
|
help="Set the maximum tokens when using piecewise cuda graph.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--torchao-config",
|
"--torchao-config",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
59
test/srt/test_piecewise_cuda_graph.py
Normal file
59
test/srt/test_piecewise_cuda_graph.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import time
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.run_eval import run_eval
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
SimpleNamespace,
|
||||||
|
popen_launch_server,
|
||||||
|
run_bench_one_batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPiecewiseCudaGraphCorrectness(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=["--enable-piecewise-cuda-graph"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
|
def test_gpqa(self):
|
||||||
|
args = SimpleNamespace(
|
||||||
|
base_url=self.base_url,
|
||||||
|
model=self.model,
|
||||||
|
eval_name="gpqa",
|
||||||
|
num_examples=64,
|
||||||
|
num_threads=16,
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics = run_eval(args)
|
||||||
|
self.assertGreaterEqual(metrics["score"], 0.235)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPiecewiseCudaGraphBenchmark(CustomTestCase):
|
||||||
|
|
||||||
|
def test_latency(self):
|
||||||
|
prefill_latency, _, _ = run_bench_one_batch(
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
|
other_args=["--enable-piecewise-cuda-graph"],
|
||||||
|
)
|
||||||
|
self.assertLess(prefill_latency, 0.015)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user