forked from EngineX-Cambricon/enginex-mlu370-vllm
add qwen3
This commit is contained in:
0
vllm-v0.6.2/vllm/compilation/__init__.py
Normal file
0
vllm-v0.6.2/vllm/compilation/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm-v0.6.2/vllm/compilation/__pycache__/levels.cpython-310.pyc
Normal file
BIN
vllm-v0.6.2/vllm/compilation/__pycache__/levels.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm-v0.6.2/vllm/compilation/__pycache__/wrapper.cpython-310.pyc
Normal file
BIN
vllm-v0.6.2/vllm/compilation/__pycache__/wrapper.cpython-310.pyc
Normal file
Binary file not shown.
691
vllm-v0.6.2/vllm/compilation/backends.py
Normal file
691
vllm-v0.6.2/vllm/compilation/backends.py
Normal file
@@ -0,0 +1,691 @@
|
||||
import copy
|
||||
import dataclasses
|
||||
import operator
|
||||
from contextlib import ExitStack
|
||||
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
|
||||
Union)
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import combine_fx_passes, weak_ref_tensors
|
||||
|
||||
from .config import CompilationConfig
|
||||
from .counter import compilation_counter
|
||||
from .fusion import FusionPass
|
||||
from .levels import CompilationLevel
|
||||
from .reshapes import RedundantReshapesPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def fix_functionalization(graph: fx.Graph):
    """
    Rewrite the graph module to replace the pattern involving
    torch._higher_order_ops.auto_functionalize.auto_functionalized
    with a direct call to the inplace custom op.

    The functionalized wrapper returns fresh tensors; each branch below
    re-inserts the in-place custom op and redirects the `getitem` users of
    the wrapper node to the tensor the op actually mutated, so the extra
    copies disappear.

    # TODO: check if PyTorch nightly has fixed this issue
    """

    # debug code, if we want to see the graph before the transformation
    # with open("before.py", "w") as f:
    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)

    # deferred deletions: erasing while iterating would invalidate traversal
    nodes_to_remove = []

    for node in graph.nodes:
        # Identify the auto_functionalized node
        if node.op == 'call_function' and node.target == torch._higher_order_ops.auto_functionalize.auto_functionalized: # noqa
            if node.args[0] == torch.ops._C.rotary_embedding.default:
                # manual replace for rotary_embedding

                # Now, collect the arguments
                kwargs = node.kwargs

                query = kwargs['query']
                # NOTE(review): assumes query comes from a getitem of a
                # matmul-like node two hops up — verify against the traced
                # model if the pattern ever changes
                mm_node = query.args[0].args[0]

                # Create a new call to torch.ops._C.rotary_embedding.default
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(torch.ops._C.rotary_embedding.default,
                                        kwargs=kwargs)

                # Remove the auto_functionalized node
                # Since the node may have outputs, we need to handle its users
                # Replace uses of the outputs (getitem nodes) with mm_node
                for user in list(node.users):
                    if user.op == 'call_function' and user.target == operator.getitem: # noqa
                        # Remove the getitem node
                        for getitem_user in list(user.users):
                            if (getitem_user.op == 'call_function'
                                    and getitem_user.target
                                    == torch.ops.aten.slice_scatter.default):
                                # Replace the uses of slice_scatter node
                                # with mm_node
                                getitem_user.replace_all_uses_with(mm_node)
                                nodes_to_remove.append(getitem_user)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

            elif node.args[0] == torch.ops._C.fused_add_rms_norm.default:
                # manual replace for fused_add_rms_norm
                # this is the most effective optimization for llama
                # failing to do this will result in many unnecessary copies

                kwargs = node.kwargs

                input = kwargs['input']
                residual = kwargs['residual']

                # Create a new call to torch.ops._C.fused_add_rms_norm.default
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs)

                for user in list(node.users):
                    if user.op == 'call_function' and user.target == operator.getitem: # noqa
                        # Remove the getitem node
                        # output 1 is the mutated input, output 2 the residual
                        if user.args[1] == 1:
                            replace_node = input
                        elif user.args[1] == 2:
                            replace_node = residual
                        # NOTE(review): replace_node is unbound if a getitem
                        # index other than 1/2 ever appears — confirm the op
                        # only exposes those two outputs
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)
            elif (node.args[0] ==
                  torch.ops._C.fused_add_rms_norm_static_fp8_quant.default):
                # manual replace for fused_add_rms_norm_static_fp8_quant
                # this is the most effective optimization for llama
                # failing to do this will result in many unnecessary copies

                kwargs = node.kwargs

                result = kwargs['result']
                residual = kwargs['residual']

                # Create a new call to
                # torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.fused_add_rms_norm_static_fp8_quant.
                        default,
                        kwargs=kwargs)

                for user in list(node.users):
                    if user.op == 'call_function' and user.target == operator.getitem: # noqa
                        # Remove the getitem node
                        # output 1 is the quantized result, output 2 residual
                        if user.args[1] == 1:
                            replace_node = result
                        elif user.args[1] == 2:
                            replace_node = residual
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

            elif node.args[0] == torch.ops._C.rms_norm.default:
                # manual replace for rms_norm

                kwargs = node.kwargs

                # single output: the pre-allocated result buffer
                replace_node = kwargs['result']
                # Create a new call to torch.ops._C.rms_norm.default
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(torch.ops._C.rms_norm.default,
                                        kwargs=kwargs)

                for user in list(node.users):
                    if user.op == 'call_function' and user.target == operator.getitem: # noqa
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

            elif node.args[
                    0] == torch.ops._C.rms_norm_static_fp8_quant.default:  # noqa
                # manual replace for rms_norm_static_fp8_quant

                kwargs = node.kwargs

                # single output: the pre-allocated result buffer
                replace_node = kwargs['result']
                # Create a new call to torch.ops._C.rms_norm_static_fp8_quant.default # noqa
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.rms_norm_static_fp8_quant.default,
                        kwargs=kwargs)

                for user in list(node.users):
                    if user.op == 'call_function' and user.target == operator.getitem: # noqa
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

            elif node.args[0] == torch.ops._C.silu_and_mul.default:
                # manual replace for silu_and_mul

                kwargs = node.kwargs

                input = kwargs['input']
                out = kwargs['out']

                # Create a new call to torch.ops._C.silu_and_mul.default
                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.silu_and_mul.default,
                        args=(out, input),
                    )
                replace_node = out

                for user in list(node.users):
                    if user.op == 'call_function' and user.target == operator.getitem: # noqa
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

    # Remove the nodes all at once
    for node in nodes_to_remove:
        graph.erase_node(node)

    # debug code, if we want to see the graph after the transformation
    # with open("after.py", "w") as f:
    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)
|
||||
|
||||
|
||||
def wrap_inductor(graph,
                  example_inputs,
                  additional_inductor_config,
                  do_logging=False,
                  runtime_shape: Optional[int] = None,
                  use_inductor: bool = True):
    """Compile ``graph`` with torch inductor, or return it untouched.

    Args:
        graph: the fx graph (module) to compile.
        example_inputs: example inputs used by inductor for tracing.
        additional_inductor_config: extra inductor config entries merged
            over the current inductor config (may be None).
        do_logging: when True, log which shape is being compiled.
        runtime_shape: the concrete shape being specialized, or None for
            the general (symbolic) shape.
        use_inductor: when False, skip compilation entirely and hand the
            graph back unchanged.

    Returns:
        The inductor-compiled callable, or ``graph`` itself when
        ``use_inductor`` is False.
    """
    if not use_inductor:
        return graph

    compilation_counter.num_inductor_compilations += 1

    if do_logging:
        if runtime_shape is None:
            logger.info("Compiling a graph for general shape")
        else:
            logger.info("Compiling a graph for shape %s", runtime_shape)

    # imported lazily: inductor is only needed on this path
    from torch._inductor import config
    from torch._inductor.compile_fx import compile_fx

    patched_config = config.shallow_copy_dict()
    if additional_inductor_config is not None:
        patched_config.update(additional_inductor_config)

    # inductor can inplace modify the graph, so we need to copy it
    # see https://github.com/pytorch/pytorch/issues/138980
    graph_copy = copy.deepcopy(graph)
    return compile_fx(graph_copy,
                      example_inputs,
                      config_patches=patched_config)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class SplitItem:
    """One subgraph produced by `split_graph`."""
    # attribute name of the submodule on the split GraphModule,
    # e.g. "submod_0"
    submod_name: str
    # integer id assigned in original node order; used for sorting
    graph_id: int
    # True if this subgraph holds one of the splitting ops; such
    # subgraphs are excluded from piecewise compilation (see VllmBackend)
    is_splitting_graph: bool
    # the submodule itself
    graph: fx.GraphModule
||||
|
||||
|
||||
def split_graph(graph: fx.GraphModule,
                ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]:
    """Split `graph` at every occurrence of the ops in `ops`.

    Each splitting op becomes its own single-op subgraph; the runs of
    nodes between them become the remaining subgraphs. Returns the
    stitched GraphModule plus the list of SplitItem entries sorted by
    their position in the original graph.
    """
    # split graph by ops
    subgraph_id = 0
    node_to_subgraph_id = {}
    # ids of the subgraphs that consist of a splitting op
    split_op_graphs = []
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if node.op == 'call_function' and str(node.target) in ops:
            # bump twice: the splitting op gets an id of its own, and the
            # nodes after it start a fresh subgraph
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.append(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = torch.fx.passes.split_module.split_module(
        graph,
        None,
        lambda node: node_to_subgraph_id[node],
        keep_original_order=True)

    outputs = []

    names = [name for (name, module) in split_gm.named_modules()]

    for name in names:
        if "." in name or name == "":
            # recursive child module or the root module
            continue

        module = getattr(split_gm, name)

        # split_module names children "submod_<graph_id>"
        graph_id = int(name.replace("submod_", ""))
        outputs.append(
            SplitItem(name, graph_id, (graph_id in split_op_graphs), module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs
|
||||
|
||||
|
||||
# we share the global graph pool among all the backends
# (created lazily in VllmBackend.__init__ via torch.cuda.graph_pool_handle)
global_graph_pool = None
|
||||
|
||||
|
||||
class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs, and compile some
    submodules specified by `compile_submod_names` with the given
    compilation configs.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(self, module: torch.fx.GraphModule,
                 compile_submod_names: List[str],
                 compilation_configs: CompilationConfig, graph_pool):
        super().__init__(module)
        from torch._guards import detect_fake_mode
        # reuse the ambient fake mode so fake tensors are consistent
        # with the ones dynamo already created
        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_configs = compilation_configs
        self.graph_pool = graph_pool

    def run(self, *args):
        """Interpret the module under fake mode, faking real tensor args."""
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode:
            return super().run(*fake_args)

    def call_module(self, target: torch.fx.node.Target,
                    args: Tuple[torch.fx.node.Argument,
                                ...], kwargs: Dict[str, Any]) -> Any:
        """Execute a submodule; if it is one of the submodules selected for
        compilation, compile it for the general shape and swap in a
        PiecewiseBackend in its place on the parent module."""
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            submod = self.fetch_attr(target)
            # positions of symbolic (batch-size) args in the call signature
            sym_shape_indices = [
                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
            ]
            compiled_graph_for_general_shape = wrap_inductor(
                submod,
                args,
                self.compilation_configs.inductor_compile_config,
                runtime_shape=None,
                do_logging=index == 0,
                use_inductor=self.compilation_configs.use_inductor)

            # replace the submodule attribute directly in __dict__ so the
            # PiecewiseBackend callable is invoked on subsequent runs
            self.module.__dict__[target] = PiecewiseBackend(
                submod, self.compilation_configs, self.graph_pool, index,
                len(self.compile_submod_names), sym_shape_indices,
                compiled_graph_for_general_shape)

            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output
|
||||
|
||||
|
||||
class VllmBackend:
    """The compilation backend for `torch.compile` with VLLM.
    It is used for compilation level of `CompilationLevel.PIECEWISE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also handles custom passes and adds them to Inductor config.
    The order of the post-grad post-passes is:
    1. post_grad_passes (constructor parameter)
    2. config["post_grad_custom_post_pass"]
    3. fix_functionalization
    This way, all passes operate on a functionalized graph.
    """

    compilation_configs: CompilationConfig
    graph_pool: Any
    # guards against a second invocation (see __call__)
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: List[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    # indices of symbolic-shaped input tensors (only used when copying
    # inputs for cudagraph)
    sym_tensor_indices: List[int]
    # compiler-managed static input buffers for cudagraph
    input_buffers: List[torch.Tensor]

    def __init__(self, post_grad_passes: Sequence[Callable] = ()):
        global global_graph_pool
        if global_graph_pool is None:
            global_graph_pool = torch.cuda.graph_pool_handle()

        # TODO: in the future, if we want to use multiple
        # streams, it might not be safe to share a global pool.
        # only investigate this when we use multiple streams
        self.graph_pool = global_graph_pool
        self.post_grad_passes = post_grad_passes

        self.sym_tensor_indices = []
        self.input_buffers = []

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def add_passes_to_config(self):
        """Assemble the post-grad pass pipeline and install it as the
        inductor `post_grad_custom_post_pass` (see class docstring for
        ordering)."""
        config = self.compilation_configs
        passes = list(self.post_grad_passes)

        passes = passes + [RedundantReshapesPass(config)]

        if config.enable_fusion:
            passes = passes + [FusionPass.instance(config)]

        inductor_config = config.inductor_compile_config
        if "post_grad_custom_post_pass" in inductor_config:
            passes = passes + [inductor_config["post_grad_custom_post_pass"]]

        # add the fix_functionalization pass last, so that all other
        # passes operate on a functionalized graph
        passes = passes + [fix_functionalization]
        combined_pass = combine_fx_passes(passes)
        inductor_config["post_grad_custom_post_pass"] = combined_pass

    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

        compilation_counter.num_graphs_seen += 1

        # we control the compilation process, each instance can only be
        # called once
        assert not self._called, "VllmBackend can only be called once"

        self.graph = graph
        # config is read now, because only here can
        # we get the sizes to capture for cudagraph
        # from compilation context
        self.compilation_configs = CompilationConfig.select_and_init_config()
        self.add_passes_to_config()

        self.split_gm, self.piecewise_graphs = split_graph(
            graph, self.compilation_configs.non_cudagraph_ops)

        from torch._dynamo.utils import lazy_format_graph_code
        logger.debug("%s", lazy_format_graph_code("before split", self.graph))
        logger.debug("%s", lazy_format_graph_code("after split",
                                                  self.split_gm))

        compilation_counter.num_piecewise_graphs_seen += len(
            self.piecewise_graphs)
        # the splitting-op subgraphs are left uncompiled
        submod_names_to_compile = [
            item.submod_name for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
                                    self.compilation_configs,
                                    self.graph_pool).run(*example_inputs)

        self._called = True

        if not self.compilation_configs.use_cudagraph or \
            not self.compilation_configs.cudagraph_copy_inputs:
            return self.split_gm

        # if we need to copy input buffers for cudagraph
        from torch._guards import detect_fake_mode
        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        # index of tensors that have symbolic shapes (batch size)
        self.sym_tensor_indices = [
            i for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        def copy_and_call(*args):
            # copy runtime inputs into the static buffers (cudagraph
            # requires fixed input addresses), then run the split module
            list_args = list(args)
            for i, index in enumerate(self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = self.input_buffers[i][:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return copy_and_call
|
||||
|
||||
|
||||
@dataclasses.dataclass
class ConcreteSizeEntry:
    """Per-shape bookkeeping for PiecewiseBackend."""
    runtime_shape: int
    need_to_compile: bool  # the size is in compile_sizes
    use_cudagraph: bool  # the size is in capture_sizes

    # whether a shape-specialized compile has been done for this size
    compiled: bool = False
    # callable to run for this size; starts as the general-shape graph
    # and may be replaced by a shape-specialized compilation
    runnable: Callable = None  # type: ignore
    # number of warmup runs executed before cudagraph capture
    num_finished_warmup: int = 0
    # the captured cudagraph, once capture has happened
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    # (weak-ref'ed) output of the captured run, returned on replay
    output: Optional[Any] = None

    # for cudagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[List[int]] = None
|
||||
|
||||
|
||||
class PiecewiseBackend:
    """Runtime wrapper for one piecewise subgraph: dispatches each call to
    the general-shape compiled graph, a shape-specialized compilation, or a
    captured cudagraph, depending on the runtime batch size (see __init__
    docstring for details)."""

    def __init__(self, graph: fx.GraphModule,
                 compilation_configs: CompilationConfig, graph_pool: Any,
                 piecewise_compile_index: int, total_piecewise_compiles: int,
                 sym_shape_indices: List[int],
                 compiled_graph_for_general_shape: Callable):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and cudagraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_configs.compile_sizes`.

        Independently, we will capture cudagraph for different shapes.

        If a shape needs both compilation and cudagraph, we will
        compile it first, and then capture cudagraph.
        """
        self.graph = graph
        self.compilation_configs = compilation_configs
        self.graph_pool = graph_pool
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles

        # the first graph owns logging/gc; the last graph owns the
        # weak-ref output optimization during capture
        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)

        self.compile_sizes: Set[int] = set(
            self.compilation_configs.compile_sizes)
        self.capture_sizes: Set[int] = set(
            self.compilation_configs.capture_sizes
        ) if self.compilation_configs.use_cudagraph else set()

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # the entries for different shapes that we need to either
        # compile or capture cudagraph
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
        for shape in self.compile_sizes.union(self.capture_sizes):
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                need_to_compile=shape in self.compile_sizes,
                use_cudagraph=shape in self.capture_sizes,
            )

    def __call__(self, *args) -> Any:
        if not self.first_run_finished:
            # the very first call runs the general-shape graph; symbolic
            # shape compilation has already happened by this point
            self.first_run_finished = True
            return self.compiled_graph_for_general_shape(*args)

        # the runtime batch size is the first symbolic-shape argument
        runtime_shape = args[self.sym_shape_indices[0]]
        if runtime_shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        entry = self.concrete_size_entries[runtime_shape]

        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.need_to_compile and not entry.compiled:
            entry.compiled = True
            # args are real arguments
            entry.runnable = wrap_inductor(
                self.graph,
                args,
                self.compilation_configs.inductor_compile_config,
                runtime_shape=runtime_shape,
                do_logging=self.is_first_graph,
                use_inductor=self.compilation_configs.use_inductor)

        if not entry.use_cudagraph:
            return entry.runnable(*args)

        if entry.cudagraph is None:
            # still warming up / about to capture
            if entry.num_finished_warmup < self.compilation_configs.cudagraph_num_of_warmups: # noqa
                entry.num_finished_warmup += 1
                if self.is_first_graph:
                    logger.debug(
                        "Warming up %s/%s for shape %s",
                        entry.num_finished_warmup,
                        self.compilation_configs.cudagraph_num_of_warmups,
                        runtime_shape)
                return entry.runnable(*args)

            if self.is_first_graph:
                # Since we capture cudagraph for many different shapes and
                # capturing is fast, we don't need to log it for every shape.
                # We only log it in the debug mode.
                logger.debug("Capturing a cudagraph for shape %s",
                             runtime_shape)

            # record input addresses so replay can sanity-check them
            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            cudagraph = torch.cuda.CUDAGraph()

            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward, we will capture
                    # many pieces of cudagraphs (roughly one per layer).
                    # running gc again and again across layers will
                    # make the cudagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.cuda.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's cudagraph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last graph
                        # will not be used by any other cuda graph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.cudagraph = cudagraph

            # NOTE(review): "caputured" is a typo, but the attribute name
            # must match the counter definition elsewhere — do not rename
            # here alone
            compilation_counter.num_cudagraph_caputured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for cudagraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )

        entry.cudagraph.replay()
        return entry.output
|
||||
|
||||
|
||||
def select_default_backend(level: int) -> Union[str, Callable]:
    """Return the default `torch.compile` backend for a compilation level.

    Dynamo-only levels use the eager backend; the piecewise level gets a
    fresh VllmBackend instance. Any other level is a programming error.
    """
    dynamo_levels = (CompilationLevel.DYNAMO_AS_IS,
                     CompilationLevel.DYNAMO_ONCE)
    if level in dynamo_levels:
        return "eager"

    # the only remaining supported level is piecewise compilation
    assert level == CompilationLevel.PIECEWISE
    return VllmBackend()
|
||||
23
vllm-v0.6.2/vllm/compilation/compile_context.py
Normal file
23
vllm-v0.6.2/vllm/compilation/compile_context.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
# module-level holder for the current compile context; read via
# get_compile_context() and swapped in/out by set_compile_context()
_compile_context: Any = None
|
||||
|
||||
|
||||
def get_compile_context() -> Any:
    """Return the compile context currently in effect (None when unset)."""
    return _compile_context
|
||||
|
||||
|
||||
@contextmanager
def set_compile_context(context: Any):
    """A context manager that stores the current compile context,
    usually it is a list of sizes to specialize.

    The previous context is saved on entry and restored on exit, even if
    the body raises, so nesting is safe.
    """
    global _compile_context
    saved = _compile_context
    _compile_context = context
    try:
        yield
    finally:
        # always restore the outer context
        _compile_context = saved
|
||||
159
vllm-v0.6.2/vllm/compilation/config.py
Normal file
159
vllm-v0.6.2/vllm/compilation/config.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import copy
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .compile_context import get_compile_context
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CompilationConfig(BaseModel):
|
||||
"""
|
||||
Configuration for compilation.
|
||||
It has two parts:
|
||||
- CudaGraph capture:
|
||||
- use_cudagraph: whether to use cudagraph inside compilation.
|
||||
- False: cudagraph inside compilation is not used.
|
||||
- True: cudagraph inside compilation is used. It requires
|
||||
that all input buffers have fixed addresses.
|
||||
Note that this is orthogonal to the cudagraph capture out
|
||||
side of compilation.
|
||||
TODO: move outside cudagraph logic into compilation.
|
||||
torch.compile will handle cudagraph capture logic in the future.
|
||||
- cudagraph_capture_sizes: sizes to capture cudagraph.
|
||||
- None: capture sizes are inferred from compilation context.
|
||||
- List[int]: capture sizes are specified.
|
||||
- cudagraph_num_of_warmups: number of warmup runs for cudagraph.
|
||||
It means the first several runs will be treated as warmup runs.
|
||||
Only after that, the execution will be recorded, and the recorded
|
||||
cudagraph will be used for subsequent runs.
|
||||
- cudagraph_copy_inputs: whether to copy input tensors for
|
||||
cudagraph. If the caller can guarantee that the same input buffers
|
||||
are always used, it can set this to False. Otherwise, it should
|
||||
set this to True, and the compiler will copy the input to an
|
||||
internally managed buffer. Default is False.
|
||||
- Inductor compilation:
|
||||
- use_inductor: whether to use inductor compilation.
|
||||
- False: inductor compilation is not used. graph runs in eager.
|
||||
- True: inductor compilation is used. one graph for symbolic shape
|
||||
is compiled. In addition, compile for different sizes specified
|
||||
in inductor_compile_sizes, using configurations
|
||||
in inductor_compile_config.
|
||||
- inductor_compile_sizes: sizes to compile for inductor.
|
||||
- inductor_specialize_for_cudagraph_no_more_than: an optional integer
|
||||
to specialize inductor for cudagraph sizes no more than the
|
||||
specified size. It is useful when we want to specialize inductor
|
||||
with a subset of cudagraph sizes.
|
||||
- inductor_compile_config: additional configurations for inductor.
|
||||
- None: use default configurations.
|
||||
- inductor_passes: additional passes for inductor. It is a dictionary
|
||||
from pass name to pass function qualified name. We use function
|
||||
name because the config uses json format. If we pass the config
|
||||
from Python, functions can also be passed directly via Python object
|
||||
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
|
||||
- Custom inductor passes:
|
||||
- dump_graph_stages: list of stages for which we want to dump the graph.
|
||||
Each pass defines its own stages (before, after, maybe in-between).
|
||||
- dump_graph_dir: directory to dump the graph. Default is .
|
||||
- enable_fusion: whether to enable the custom fusion pass.
|
||||
TODO better pass enabling system.
|
||||
|
||||
Why we have different sizes for cudagraph and inductor:
|
||||
- cudagraph: a cudagraph captured for a specific size can only be used
|
||||
for the same size. We need to capture all the sizes we want to use.
|
||||
- inductor: a graph compiled by inductor for a general shape can be used
|
||||
for different sizes. Inductor can also compile for specific sizes,
|
||||
where it can have more information to optimize the graph with fully
|
||||
static shapes. However, we find the general shape compilation is
|
||||
sufficient for most cases. It might be beneficial to compile for
|
||||
certain small batchsizes, where inductor is good at optimizing.
|
||||
"""
|
||||
    # whether to compile graphs with inductor (vs. running them eagerly)
    use_inductor: bool = True
    # if set, only specialize inductor for capture sizes up to this value;
    # mutually exclusive with inductor_compile_sizes (see init_during_runtime)
    inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None
    # NOTE(review): default_factory=dict yields {} for a List-typed field,
    # and init_during_runtime asserts this is None when
    # inductor_specialize_for_cudagraph_no_more_than is set -- confirm the
    # intended default.
    inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict)
    # extra configuration forwarded to inductor compilation
    inductor_compile_config: Dict = Field(default_factory=dict)
    # pass name -> qualified function name; resolved to callables in
    # model_post_init and merged into inductor_compile_config
    inductor_passes: Dict[str, str] = Field(default_factory=dict)

    use_cudagraph: bool = False
    # ops excluded from cudagraph capture
    non_cudagraph_ops: List[str] = Field(default_factory=list)
    # warmup runs before each cudagraph capture
    cudagraph_num_of_warmups: int = 0
    # explicit capture sizes; None means use the model runner's sizes
    cudagraph_capture_sizes: Optional[List[int]] = None
    # copy inputs into graph-owned buffers before replay
    cudagraph_copy_inputs: bool = False

    # stages at which passes dump the fx graph (see InductorPass.dump_graph)
    dump_graph_stages: List[str] = Field(default_factory=list)
    dump_graph_dir: Path = Field(default=Path("."))
    # toggles the custom RMSNorm+quant FusionPass
    enable_fusion: bool = True

    # not configurable, computed after init
    # NOTE(review): these assign the PrivateAttr class itself rather than
    # calling PrivateAttr(); presumably tolerated because both are
    # overwritten in init_during_runtime before any read -- verify.
    compile_sizes: List[int] = PrivateAttr
    capture_sizes: List[int] = PrivateAttr
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
for k, v in self.inductor_passes.items():
|
||||
if not isinstance(v, str):
|
||||
assert callable(v), (
|
||||
f"pass {k} should be a function or a qualified name")
|
||||
self.inductor_compile_config[k] = v
|
||||
continue
|
||||
|
||||
# resolve function from qualified name
|
||||
names = v.split(".")
|
||||
module = ".".join(names[:-1])
|
||||
func_name = names[-1]
|
||||
func = __import__(module).__dict__[func_name]
|
||||
self.inductor_compile_config[k] = func
|
||||
|
||||
    def init_during_runtime(self):
        """To complete the initialization of config,
        we need to know the compile context, which is only available
        during the first run of the model.

        Populates `capture_sizes` (batch sizes to capture cudagraphs for)
        and `compile_sizes` (batch sizes to specialize inductor for).
        """
        context = get_compile_context()
        # the context carries the batch sizes announced by the model runner;
        # deep-copy so mutation here cannot leak back to the caller
        context = copy.deepcopy(context) if context is not None else []
        sizes_to_specialize: List[int] = context
        if self.cudagraph_capture_sizes is None:
            self.capture_sizes = sizes_to_specialize
        else:
            # an explicit config wins over the runner-provided sizes
            self.capture_sizes = self.cudagraph_capture_sizes
            logger.info(("cudagraph sizes specified by model runner"
                         " %s is overridden by config %s"),
                        sizes_to_specialize, self.cudagraph_capture_sizes)
        if self.inductor_specialize_for_cudagraph_no_more_than is not None:
            # NOTE(review): the field default of inductor_compile_sizes is {}
            # (default_factory=dict), not None, so this assert fires unless
            # the user explicitly set it to None -- confirm intended behavior.
            assert self.inductor_compile_sizes is None, (
                "inductor_compile_sizes should be None when "
                "inductor_specialize_for_cudagraph_no_more_than is not None")
            # specialize only for capture sizes up to the threshold
            self.compile_sizes = [
                x for x in self.capture_sizes
                if x <= self.inductor_specialize_for_cudagraph_no_more_than
            ]
        else:
            assert self.inductor_compile_sizes is not None, (
                "inductor_compile_sizes should not be None when "
                "inductor_specialize_for_cudagraph_no_more_than is None")
            self.compile_sizes = self.inductor_compile_sizes
|
||||
|
||||
@staticmethod
|
||||
def select_and_init_config() -> "CompilationConfig":
|
||||
"""The order of selecting config is:
|
||||
1. Use the config specified in environment variable.
|
||||
2. Use the config specified in plugins.
|
||||
3. Use the default config.
|
||||
"""
|
||||
config_path = envs.VLLM_TORCH_COMPILE_CONFIG
|
||||
if config_path is not None:
|
||||
with open(config_path) as json_file:
|
||||
config = CompilationConfig.model_validate_json(
|
||||
json_file.read())
|
||||
else:
|
||||
from vllm.plugins import get_compilation_config
|
||||
predefined_config = get_compilation_config()
|
||||
config = predefined_config if predefined_config is not None else (
|
||||
CompilationConfig())
|
||||
|
||||
config.init_during_runtime()
|
||||
return config
|
||||
30
vllm-v0.6.2/vllm/compilation/counter.py
Normal file
30
vllm-v0.6.2/vllm/compilation/counter.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import copy
|
||||
import dataclasses
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@dataclasses.dataclass
class CompilationCounter:
    """Global counters tracking how many graphs were seen, compiled by
    inductor, and captured into cudagraphs. Mainly used by tests via
    `expect` to assert that a code region performed the expected amount
    of compilation work."""

    num_graphs_seen: int = 0
    # including the splitting ops
    num_piecewise_graphs_seen: int = 0
    # not including the splitting ops
    num_piecewise_capturable_graphs_seen: int = 0
    num_inductor_compilations: int = 0
    # NOTE: the "caputured" typo is preserved; external code may reference it
    num_cudagraph_caputured: int = 0

    def clone(self) -> "CompilationCounter":
        """Return an independent snapshot of the current counters."""
        return copy.deepcopy(self)

    @contextmanager
    def expect(self, **kwargs):
        """Assert that each named counter grows by exactly the given delta
        while the with-block executes."""
        snapshot = self.clone()
        yield
        for name, delta in kwargs.items():
            before = getattr(snapshot, name)
            after = getattr(self, name)
            assert after - before == delta, (
                f"{name} not as expected, before it is {before}"
                f", after it is {after}, "
                f"expected diff is {delta}")


compilation_counter = CompilationCounter()
|
||||
182
vllm-v0.6.2/vllm/compilation/decorators.py
Normal file
182
vllm-v0.6.2/vllm/compilation/decorators.py
Normal file
@@ -0,0 +1,182 @@
|
||||
import inspect
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.levels import CompilationLevel
|
||||
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import supports_dynamo
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def support_torch_compile(
        cls: Optional[type] = None,
        dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None):
    """
    Class decorator that makes the `forward` method of a module
    compilable with torch.compile.

    Usable bare (`@support_torch_compile`) or parameterized
    (`@support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})`), where
    `dynamic_arg_dims` maps forward-argument names to the dimension(s)
    that should be marked dynamic (a single int or a list of ints).

    When `dynamic_arg_dims` is omitted it is inferred from the `forward`
    annotations: every argument annotated as `torch.Tensor`,
    `Optional[torch.Tensor]`, `IntermediateTensors` or
    `Optional[IntermediateTensors]` gets its first dimension marked
    dynamic (for `IntermediateTensors`, the first dimension of every
    contained tensor).

    NOTE: an argument that is `None` must stay `None` for the lifetime of
    the model, otherwise a single compiled graph cannot cover all calls.
    """

    def cls_decorator_helper(target: type):
        # validate eagerly so errors point at the decorated class,
        # not at the first forward call
        if not hasattr(target, 'forward'):
            raise TypeError("decorated class should have a forward method.")
        sig = inspect.signature(target.forward)
        dims = dynamic_arg_dims
        if dims is None:
            # infer: any tensor-like annotated argument is dynamic in dim 0
            tensor_like = [
                torch.Tensor, Optional[torch.Tensor], IntermediateTensors,
                Optional[IntermediateTensors]
            ]
            dims = {
                name: 0
                for name, param in sig.parameters.items()
                if param.annotation in tensor_like
            }

            logger.debug(("Inferred dynamic dimensions for "
                          "forward method of %s: %s"), target,
                         list(dims.keys()))

        if not dims:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{target}. Please provide dynamic_arg_dims explicitly.")

        for name in dims:
            if name not in sig.parameters:
                raise ValueError(
                    f"Argument {name} not found in the forward method of "
                    f"{target}")
        return _support_torch_compile(target, dims)

    if cls is not None:
        # bare decorator form: @support_torch_compile
        assert isinstance(cls, type)
        return cls_decorator_helper(cls)

    # parameterized form: @support_torch_compile(...)
    return cls_decorator_helper
|
||||
|
||||
|
||||
def _support_torch_compile(cls: type,
                           dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
    """
    A decorator to add support for compiling the forward method of a class.

    Rewires `cls` so that `__call__` dispatches through
    TorchCompileWrapperWithCustomDispatcher: the first invocation runs under
    torch.compile (with the given dims marked dynamic); later invocations
    may jump directly to the captured bytecode.
    """
    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
        # support decorating multiple times
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    # other than TorchCompileWrapperWithCustomDispatcher
    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

    old_init = cls.__init__  # type: ignore

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
        # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
        # will handle the compilation, so we don't need to do anything here.
        self.do_not_compile = envs.VLLM_TORCH_COMPILE_LEVEL in [
            CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
        ] or not supports_dynamo()
        if self.do_not_compile:
            return
        TorchCompileWrapperWithCustomDispatcher.__init__(self)

    cls.__init__ = __init__  # type: ignore

    def __call__(self, *args, **kwargs):
        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
        # need to compile the model inside.
        if self.do_not_compile or torch.compiler.is_compiling():
            return self.forward(*args, **kwargs)

        # the first compilation needs to have dynamic shapes marked
        if len(self.compiled_codes) < 1:
            sig = inspect.signature(self.__class__.forward)
            bound_args = sig.bind(self, *args, **kwargs)
            bound_args.apply_defaults()
            for k, dims in dynamic_arg_dims.items():
                arg = bound_args.arguments.get(k)
                if arg is not None:
                    # mark_dynamic accepts a single dim or a list of dims
                    if isinstance(arg, torch.Tensor):
                        torch._dynamo.mark_dynamic(arg, dims)
                    elif isinstance(arg, IntermediateTensors):
                        for tensor in arg.tensors.values():
                            torch._dynamo.mark_dynamic(tensor, dims)
                    else:
                        raise ValueError(
                            "Unsupported dynamic dimensions"
                            f" {dims} for argument {k} with type {type(arg)}.")

        # if we don't use custom dispatcher, we can directly call the
        # compiled function and let torch.compile handle the dispatching,
        # with the overhead of guard evaluation and recompilation.
        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
            # it seems Dynamo reuse the compilation across instances,
            # while we need to make sure the compiled code is not reused.
            # we need to control all the compilation of the model.
            torch._dynamo.eval_frame.remove_from_cache(
                self.original_code_object)
            return self.compiled_callable(*args, **kwargs)

        # usually, capturing the model once is enough, and then we can
        # dispatch to the compiled code directly, without going through
        # the Dynamo guard mechanism.
        with self.dispatch_to_code(0):
            model_output = self.forward(*args, **kwargs)
            return model_output

    cls.__call__ = __call__  # type: ignore
    return cls
|
||||
291
vllm-v0.6.2/vllm/compilation/fusion.py
Normal file
291
vllm-v0.6.2/vllm/compilation/fusion.py
Normal file
@@ -0,0 +1,291 @@
|
||||
import operator
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import (Match, PatternMatcherPass,
|
||||
fwd_only, register_replacement)
|
||||
|
||||
from vllm.compilation.config import CompilationConfig
|
||||
from vllm.compilation.inductor_pass import InductorPass
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor,
                       input: torch.Tensor, weight: torch.Tensor,
                       scale: torch.Tensor):
    """Match target: rms_norm followed by static fp8 quantization."""
    normed = auto_functionalized(torch.ops._C.rms_norm.default,
                                 result=result_rms,
                                 input=input,
                                 weight=weight,
                                 epsilon=1e-5)
    quantized = auto_functionalized(
        torch.ops._C.static_scaled_fp8_quant.default,
        result=result,
        input=normed[1],
        scale=scale)

    # the quantized result tensor
    return quantized[1]
|
||||
|
||||
|
||||
def rms_replacement_static(result: torch.Tensor, result_rms: torch.Tensor,
                           input: torch.Tensor, weight: torch.Tensor,
                           scale: torch.Tensor):
    """Replacement: a single fused rms_norm + static fp8 quant op."""
    fused = auto_functionalized(
        torch.ops._C.rms_norm_static_fp8_quant.default,
        result=result,
        input=input,
        weight=weight,
        scale=scale,
        epsilon=1e-5)

    # the quantized result tensor
    return fused[1]
|
||||
|
||||
|
||||
def rms_pattern_residual_static(result: torch.Tensor, input: torch.Tensor,
                                residual: torch.Tensor, weight: torch.Tensor,
                                scale: torch.Tensor):
    """Match target: fused_add_rms_norm followed by static fp8 quant."""
    normed = auto_functionalized(torch.ops._C.fused_add_rms_norm.default,
                                 input=input,
                                 residual=residual,
                                 weight=weight,
                                 epsilon=1e-5)
    quantized = auto_functionalized(
        torch.ops._C.static_scaled_fp8_quant.default,
        result=result,
        input=normed[1],
        scale=scale)

    # (quantized result, updated residual)
    return quantized[1], normed[2]
|
||||
|
||||
|
||||
def rms_replacement_residual_static(result: torch.Tensor, input: torch.Tensor,
                                    residual: torch.Tensor,
                                    weight: torch.Tensor, scale: torch.Tensor):
    """Replacement: a single fused add + rms_norm + static fp8 quant op."""
    fused = auto_functionalized(
        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
        result=result,
        input=input,
        residual=residual,
        weight=weight,
        scale=scale,
        epsilon=1e-5)

    # (quantized result, updated residual)
    return fused[1], fused[2]
|
||||
|
||||
|
||||
def empty_bf16(*args, device: str = "cuda", **kwargs):
    """Allocate an uninitialized bfloat16 tensor (CUDA by default).

    `device` is now overridable; previously it was hard-coded to "cuda",
    so passing device via kwargs raised a duplicate-argument TypeError.
    """
    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device=device)
|
||||
|
||||
|
||||
def empty_fp8(*args, **kwargs):
|
||||
fp8 = torch.float8_e4m3fn
|
||||
return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")
|
||||
|
||||
|
||||
def empty_fp32(*args, device: str = "cuda", **kwargs):
    """Allocate an uninitialized float32 tensor (CUDA by default).

    `device` is now overridable; previously it was hard-coded to "cuda",
    so passing device via kwargs raised a duplicate-argument TypeError.
    """
    return torch.empty(*args, **kwargs, dtype=torch.float32, device=device)
|
||||
|
||||
|
||||
# Utilities for post-processing multi-output matches
|
||||
def is_func(node: torch.fx.Node, target) -> bool:
    """True iff `node` is a call_function node whose target is `target`."""
    if node.op != "call_function":
        return False
    return node.target == target
|
||||
|
||||
|
||||
# Returns the first auto_functionalized node with the given op (if it exists)
|
||||
def find_auto_fn_maybe(nodes: Iterable[torch.fx.Node],
                       op) -> Optional[torch.fx.Node]:
    """Return the first auto_functionalized node wrapping `op`, or None."""
    candidates = (
        n for n in nodes
        if is_func(n, auto_functionalized) and n.args[0] == op)  # noqa
    return next(candidates, None)
|
||||
|
||||
|
||||
# Returns the first auto_functionalized node with the given op
|
||||
def find_auto_fn(nodes: Iterable[torch.fx.Node], op) -> torch.fx.Node:
    """Like find_auto_fn_maybe, but asserts that a match exists."""
    found = find_auto_fn_maybe(nodes, op)
    assert found is not None, f"Could not find {op} in nodes {nodes}"
    return found
|
||||
|
||||
|
||||
# Returns the getitem node that extracts the idx-th element from node
|
||||
# (if it exists)
|
||||
def find_getitem_maybe(node: torch.fx.Node,
                       idx: int) -> Optional[torch.fx.Node]:
    """Return the user of `node` that is `operator.getitem(node, idx)`,
    if one exists."""
    return next(
        (u for u in node.users
         if is_func(u, operator.getitem) and u.args[1] == idx), None)
|
||||
|
||||
|
||||
# Returns the getitem node that extracts the idx-th element from node
|
||||
def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node:
    """Like find_getitem_maybe, but asserts that the getitem exists."""
    found = find_getitem_maybe(node, idx)
    assert found is not None, f"Could not find getitem {idx} in node {node}"
    return found
|
||||
|
||||
|
||||
class FusionPass(InductorPass):
    """
    This pass fuses a pre-defined set of custom ops into fused ops.
    It uses the torch pattern matcher to find the patterns and replace them.
    It also manually processes multi-output matches, as those are broken in
    the torch pattern matcher.

    Because patterns can only be registered once, the pass is a singleton.
    This will be addressed in a future version of PyTorch:
    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
    """

    # the singleton instance, created lazily by instance()
    _instance: 'Optional[FusionPass]' = None

    @classmethod
    def instance(cls, config: CompilationConfig):
        """
        Get the singleton instance of the FusionPass.
        If the instance exists, the config is updated but
        initialization is not repeated.
        """
        if cls._instance is None:
            cls._instance = FusionPass(config)
        else:
            cls._instance.config = config
        return cls._instance

    def __init__(self, config: CompilationConfig):
        # use instance() rather than calling the constructor directly
        assert self.__class__._instance is None, \
            "FusionPass singleton instance already exists"
        super().__init__(config)

        # multi-output matches recorded by record_match for manual
        # processing in process_matches
        self.matches: List[Match] = []
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="fusion_pass")

        # Fuse rms_norm + static_scaled_fp8_quant into
        # rms_norm_static_fp8_quant
        # NOTE(review): the example inputs below are allocated on CUDA by
        # the empty_* helpers, so constructing this pass requires a GPU.
        inputs = [
            empty_fp8(5, 4),
            empty_bf16(5, 4),
            empty_bf16(5, 4),
            empty_bf16(1, 5),
            empty_fp32(1, 1)
        ]
        register_replacement(rms_pattern_static, rms_replacement_static,
                             inputs, fwd_only, self.patterns)

        # Fuse fused_add_rms_norm + static_scaled_fp8_quant into
        # fused_add_rms_norm_static_fp8_quant
        # Because pattern has 2 outputs, we need to manually process the match
        # (see process_matches)
        inputs = [
            empty_fp8(5, 4),
            empty_bf16(5, 4),
            empty_bf16(5, 4),
            empty_bf16(1, 5),
            empty_fp32(1, 1)
        ]
        register_replacement(rms_pattern_residual_static,
                             rms_replacement_residual_static,
                             inputs,
                             fwd_only,
                             self.patterns,
                             extra_check=lambda m: self.record_match(m))

    def record_match(self, match: Match) -> bool:
        # Hijack the extra_check to record the match and
        # save it for post-processing.
        self.matches.append(match)

        # Return False to prevent automatic replacement.
        return False

    def process_matches(self, graph: torch.fx.Graph):
        """
        Manually process multi-output matches and replace them with fused nodes.
        This is necessary because the automatic replacement for multi-output
        matches is broken: https://github.com/pytorch/pytorch/issues/137280
        """
        for match in self.matches:
            # To avoid use-before-definition errors, insert replacement nodes
            # after the last node in the match.
            # match.nodes is not guaranteed to be sorted.
            # Find the last node in the match.
            for last_node_in_match in reversed(graph.nodes):
                if last_node_in_match in match.nodes:
                    break
            else:
                raise ValueError("No nodes in graph")

            # Insert a new auto_functionalized node for the fused operation,
            # as well as getitem nodes to extract the result and residual.
            # The auto_functionalized node returns a tuple of
            # (None, result, residual) - None is the function return value.
            # The resulting graph looks like this:
            # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...)  # noqa
            # result_node_new = at[1]
            # residual_node_new = at[2]
            with graph.inserting_after(last_node_in_match):
                kwargs = match.kwargs
                kwargs["epsilon"] = 1e-5  # Currently hard-coded in RMSNorm

                fused_node = graph.call_function(
                    auto_functionalized,
                    (torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
                     ),
                    kwargs=kwargs)

                graph.inserting_after(fused_node)
                result_node_new = graph.call_function(operator.getitem,
                                                      (fused_node, 1))
                residual_node_new = graph.call_function(
                    operator.getitem, (fused_node, 2))

            # Last part of replacement is rebinding the users of nodes in the
            # match to use the new nodes.

            # Find the nodes in the match that we need to rebind
            rms_node = find_auto_fn(match.nodes,
                                    torch.ops._C.fused_add_rms_norm.default)
            quant_node = find_auto_fn(
                match.nodes, torch.ops._C.static_scaled_fp8_quant.default)

            # expected users: the two getitems on rms, one getitem on quant
            assert len(rms_node.users) == 2
            assert len(quant_node.users) == 1

            # meta["val"] is used by de-functionalization and has to contain the
            # value of the node (tuple of tensors) that would be returned by the
            # functionalized node during tracing.

            rms_tup = rms_node.meta["val"]
            quant_tup = quant_node.meta["val"]

            # The result of fused_node must be a tuple with the first element
            # None (the function return value) and the remaining elements
            # representing the mutated inputs.
            fused_tup = (None, quant_tup[1], rms_tup[1], rms_tup[2])
            fused_node.meta["val"] = fused_tup

            # Find the getitem nodes and replace their uses with the new nodes.
            # The old nodes will be removed by DCE at the end of the pass.
            find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new)
            find_getitem(quant_node, 1).replace_all_uses_with(result_node_new)

        # Finally, remove matched nodes
        graph.eliminate_dead_code()
        assert all(node not in graph.nodes for match in self.matches
                   for node in match.nodes)

    def __call__(self, graph: torch.fx.Graph):
        self.dump_graph(graph, "before_fusion")

        count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", count)
        self.dump_graph(graph, "after_pattern_match")

        # Manually process multi-output matches (and run DCE)
        self.process_matches(graph)
        logger.debug("Post-processed %s matches", len(self.matches))
        self.dump_graph(graph, "after_fusion")
        # reset so matches from this run don't carry into the next run
        self.matches.clear()
|
||||
38
vllm-v0.6.2/vllm/compilation/inductor_pass.py
Normal file
38
vllm-v0.6.2/vllm/compilation/inductor_pass.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.compilation.config import CompilationConfig
|
||||
# yapf: disable
|
||||
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_world_size as get_tp_world_size)
|
||||
from vllm.distributed import model_parallel_is_initialized as p_is_init
|
||||
# yapf: enable
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class InductorPass(ABC):
    """Base class for custom fx-graph passes run inside inductor.

    Subclasses implement __call__ to transform the graph in place, and may
    call dump_graph to write the graph's python code to disk at selected
    stages (controlled by config.dump_graph_stages).
    """

    def __init__(self, config: CompilationConfig):
        self.config = config

    @abstractmethod
    def __call__(self, graph: torch.fx.Graph):
        raise NotImplementedError

    def dump_graph(self, graph: torch.fx.Graph, stage: str):
        """Write the graph's source to dump_graph_dir if `stage` is one of
        the configured dump stages."""
        if stage not in self.config.dump_graph_stages:
            return
        # Make sure filename includes rank in the distributed setting
        is_parallel = p_is_init() and get_tp_world_size() > 1
        rank_suffix = f"-{get_tp_rank()}" if is_parallel else ""
        filepath = self.config.dump_graph_dir / f"{stage}{rank_suffix}.py"

        logger.info("Printing graph to %s", filepath)
        with open(filepath, "w") as f:
            src = graph.python_code(root_module="self", verbose=True).src
            # Add imports so it's not full of errors
            print("import torch; from torch import device", file=f)
            print(src, file=f)
|
||||
8
vllm-v0.6.2/vllm/compilation/levels.py
Normal file
8
vllm-v0.6.2/vllm/compilation/levels.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# constants for the levels of the compilation process
|
||||
|
||||
|
||||
class CompilationLevel:
    # Ordered compilation levels. Code elsewhere compares numerically
    # (e.g. `VLLM_TORCH_COMPILE_LEVEL >= DYNAMO_ONCE`), so the values matter.
    NO_COMPILATION = 0
    # the upper-level model runner drives torch.compile itself
    DYNAMO_AS_IS = 1
    # compile once via Dynamo; dispatch directly to the captured bytecode
    DYNAMO_ONCE = 2
    # split the graph and compile/capture it piecewise
    PIECEWISE = 3
|
||||
85
vllm-v0.6.2/vllm/compilation/reshapes.py
Normal file
85
vllm-v0.6.2/vllm/compilation/reshapes.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from typing import Union
|
||||
|
||||
import torch.fx
|
||||
from torch import SymInt
|
||||
|
||||
from vllm.compilation.fusion import is_func
|
||||
from vllm.compilation.inductor_pass import InductorPass
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RedundantReshapesPass(InductorPass):
    """
    This is an inductor pass that removes redundant reshape operations.
    It is required for RMSNorm-quant fusion to work properly.
    That's because apply_fp8_linear adds a reshape, which is redundant
    in the 2D-case.

    Example graph:

    getitem_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
    at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Can be replaced with:
    getitem_1: "f16[s0, 4096]" = ...
    at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]
    """

    def __call__(self, graph: torch.fx.Graph):
        self.dump_graph(graph, "before_reshapes")
        count = 0
        # Remove no-op reshapes/views:
        # NOTE(review): nodes are erased while iterating graph.nodes;
        # fx's node list tolerates removal during iteration, but confirm
        # against the targeted torch version.
        for node in graph.nodes:
            if is_func(node, torch.ops.aten.reshape.default):
                input, shape = node.args[:2]
                # shape of the reshape's input, from fx shape propagation
                input_shape = input.meta["val"].shape
                if len(shape) != len(input_shape):
                    # Reshape changing rank, skip
                    continue

                if shape.count(-1) > 1:
                    # Invalid reshape args, skip
                    continue

                # the reshape is a no-op only if every target dim matches
                # the corresponding input dim (or is the inferred -1)
                if all(
                        self.dims_equivalent(s, i_s)
                        for s, i_s in zip(shape, input_shape)):
                    node.replace_all_uses_with(input)
                    graph.erase_node(node)
                    count += 1

        logger.debug("Removed %s no-op reshapes", count)

        self.dump_graph(graph, "after_reshapes")

    def dims_equivalent(self, dim: Union[int, torch.fx.Node],
                        i_dim: Union[int, SymInt]) -> bool:
        """
        This function checks if two dimensions are equivalent.
        :param dim: The dimension arg to reshape
        :param i_dim: The corresponding dimension in the input tensor
        :return: Are the dimensions equivalent?

        There are three cases in which the dimensions are equivalent:
        1. The dimensions are equal (both integers)
        2. The reshape dimension is -1 (i.e. inferred)
        3. The dimensions both correspond to the same SymInt

        While case 2 does not guarantee the dimensions are equal,
        they are equal if all other dimensions are equal.

        In case 3, the reshape dimension is a torch.fx.Node,
        and its value is a SymInt. That value is equal to the
        input dimension.

        """
        # Case 1 and 2
        if dim == i_dim or dim == -1:
            return True
        # Case 3
        # NOTE(review): assumes symbolic dims always carry meta["val"];
        # a node without it would raise KeyError here.
        return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim
|
||||
102
vllm-v0.6.2/vllm/compilation/wrapper.py
Normal file
102
vllm-v0.6.2/vllm/compilation/wrapper.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from types import CodeType
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
|
||||
from .levels import CompilationLevel
|
||||
|
||||
|
||||
class TorchCompileWrapperWithCustomDispatcher:
    """
    A wrapper class for torch.compile, with a custom dispatch logic.
    Subclasses should:
    1. Implement the forward method
    2. Implement the dispatch logic in the __call__ method
    It can use `self.compiled_codes` to access the compiled bytecode,
    and `with self.dispatch_to_code(index):` to dispatch to
    the compiled code.
    3. Implement the `__init__` method to determine how to call
    `torch.compile` over the forward method.
    """

    def __init__(self, compiled_callable: Optional[Callable] = None):

        if compiled_callable is None:
            # default compilation settings
            # compiling the forward method

            # choose the compile backend

            # if the user has set the backend, use it
            from vllm.plugins import get_torch_compile_backend
            backend = get_torch_compile_backend()
            if backend is None:
                from vllm.compilation.backends import select_default_backend
                backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL)

            compiled_callable = torch.compile(
                self.forward,
                fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                backend=backend)

        self.compiled_callable = compiled_callable
        # the pristine bytecode of forward; used to recognize our own frame
        # in bytecode_hook and to restore after dispatch_to_code
        self.original_code_object = self.__class__.forward.__code__
        self.compiled_codes: List[CodeType] = []
        # Dynamo invokes bytecode_hook each time it finishes compiling a frame
        torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)

        # read the env var to determine whether to use the custom dispatcher
        # subclasses can use this to switch between the custom dispatcher
        # and the default Dynamo guard mechanism.
        self.use_custom_dispatcher: bool = \
            envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.DYNAMO_ONCE

    def __call__(self, *args, **kwargs):
        """Implement the dispatch logic here, beyond the torch.compile level.
        NOTE: this function can have additional arguments beyond the forward
        method, for directly dispatching to the compiled code.
        """
        return self.compiled_callable(*args, **kwargs)

    @abstractmethod
    def forward(self, *args, **kwargs):
        ...

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
        """Hook to save the compiled bytecode for direct execution."""
        if old_code is not self.original_code_object:
            return
        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        # walk up the call stack to find Dynamo's `_compile` frame, which
        # holds the frame being compiled in its local variable `frame`
        frame = sys._getframe()
        while frame and frame.f_back:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        # the hook fires for every instance sharing this code object; only
        # record bytecode compiled for *this* instance
        if frame.f_locals["self"] is not self:
            return

        self.compiled_codes.append(new_code)

    @contextmanager
    def dispatch_to_code(self, index: int):
        """Context manager to dispatch to the compiled code.
        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
        """  # noqa
        self.__class__.forward.__code__ = self.compiled_codes[index]
        yield
        self.__class__.forward.__code__ = self.original_code_object
|
||||
Reference in New Issue
Block a user