Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -282,61 +282,13 @@ class CompilerManager:
maybe_key += f"{compile_range.start}_{compile_range.end}"
maybe_key += f"_subgraph_{graph_index}"
with self.compile_context(compile_range):
# There is a compilation-time optimization here.
#
# If the (input metadata, graph, compiler config) are the same, then
# we want to avoid compiling the same artifact again. Without this
# optimization, the backend compilation (InductorAdaptor or
# InductorStandaloneAdaptor) can still hit its own cache and produce
# the artifact faster if it was already created, but that still
# yields a duplicate artifact and incurs unnecessary work, e.g. disk IO.
#
# The optimization is: If the backend compilation cache hits,
# then do an early return from the backend compilation and look up
# which of the previous in-memory artifacts we created to reuse.
#
# We implemented this by monkey-patching torch (torch does not
# easily expose the cache_key function), but in the future torch
# should expose the cache_key function that we can just call
# directly before invoking backend compilation.
cache_key = None
orig = torch._functorch._aot_autograd.autograd_cache.autograd_cache_key
def autograd_cache_key(*args, **kwargs):
result = orig(*args, **kwargs)
if result is None:
return None
nonlocal cache_key
cache_key = result[0]
if cache_key in self.loaded_artifacts:
raise StopCompiling()
return result
from unittest.mock import patch
with (
# Graphs that are isomorphic (different node names but same
# structure) should be treated as the same.
torch._functorch.config.patch(autograd_cache_normalize_inputs=True),
patch(
"torch._functorch._aot_autograd.autograd_cache.autograd_cache_key",
autograd_cache_key,
),
):
try:
compiled_graph, handle = self.compiler.compile(
graph,
example_inputs,
additional_inductor_config,
compile_range,
maybe_key,
)
except StopCompiling:
assert cache_key is not None
return self.loaded_artifacts[cache_key]
if cache_key is not None and compiled_graph is not None:
self.loaded_artifacts[cache_key] = compiled_graph
compiled_graph, handle = self.compiler.compile(
graph,
example_inputs,
additional_inductor_config,
compile_range,
maybe_key,
)
assert compiled_graph is not None, "Failed to compile the graph"
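For reference, the cache-key interception removed in the CompilerManager.compile hunk above follows a general pattern: temporarily patch the backend's cache-key function, raise a sentinel exception when the key is already known, and reuse the in-memory artifact instead of recompiling. A minimal sketch of that pattern, assuming illustrative names (compile_with_reuse and key_fn_path are not vLLM APIs):

from unittest.mock import patch


class StopCompiling(Exception):
    """Sentinel raised to abort backend compilation once a cache hit is seen."""


loaded_artifacts: dict[str, object] = {}  # in-memory artifacts keyed by cache key


def compile_with_reuse(compile_fn, key_fn_path, orig_key_fn, *compile_args):
    cache_key = None

    def patched_key_fn(*args, **kwargs):
        nonlocal cache_key
        result = orig_key_fn(*args, **kwargs)
        if result is None:
            return None
        cache_key = result[0]
        if cache_key in loaded_artifacts:
            # Short-circuit: the artifact for this key already exists in memory.
            raise StopCompiling()
        return result

    with patch(key_fn_path, patched_key_fn):
        try:
            artifact = compile_fn(*compile_args)
        except StopCompiling:
            return loaded_artifacts[cache_key]
    if cache_key is not None and artifact is not None:
        loaded_artifacts[cache_key] = artifact
    return artifact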
@@ -497,7 +449,7 @@ def wrap_with_cudagraph_if_needed(
# it from the FULL cudagraph runtime mode, no matter whether it
# is wrapped around a full or a piecewise fx graph.
return static_graph_wrapper_class(
runnable=piecewise_backend,
runnable=piecewise_backend.graph.forward,
vllm_config=vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
cudagraph_options=CUDAGraphOptions(
@@ -780,7 +732,7 @@ class VllmBackend:
return standalone_compile_artifacts, sym_shape_indices_map, returns_tuple_map
def configure_post_pass(self) -> None:
# self.pass_manager.configure(self.vllm_config)
self.pass_manager.configure(self.vllm_config)
# Post-grad custom passes are run using the post_grad_custom_post_pass
# hook. If a pass for that hook exists, add it to the pass manager.
@@ -846,7 +798,7 @@ class VllmBackend:
),
)
def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any], **kwargs) -> Any:
from .caching import (
VllmSerializableFunction,
)
@@ -988,7 +940,7 @@ class VllmBackend:
assert not self._called, "VllmBackend can only be called once"
self.graph = graph
self.configure_post_pass()
# self.configure_post_pass()
if self.compilation_config.use_inductor_graph_partition:
# Let Inductor decide partitioning; avoid FX-level pre-splitting.

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import hashlib
import inspect
import os
@@ -144,6 +145,18 @@ class StandaloneCompiledArtifacts:
self.loaded_submodule_store = {}
@contextlib.contextmanager
def patch_pytree_map_over_slice():
pytree._private_register_pytree_node(
slice, lambda x: ([x.start, x.stop, x.step], None), lambda x, c: slice(*x)
)
try:
yield
finally:
pytree._deregister_pytree_node(slice)
class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
"""
A wrapper around a compiled function by vllm. It will forward the tensor
@@ -235,7 +248,10 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
lambda inp: torch.empty_like(inp, device="meta"),
state["example_inputs"],
)
with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
with (
patch.object(GraphPickler, "reducer_override", _graph_reducer_override),
patch_pytree_map_over_slice(),
):
state["graph_module"] = GraphPickler.dumps(
state["graph_module"], Options(ops_filter=None)
)
@@ -261,7 +277,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
state = pickle.loads(data)
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
with patch_pytree_map_over_slice():
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
state["graph_module"].recompile()
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)

View File

@@ -184,6 +184,47 @@ def is_compile_cache_enabled(
)
def _patch_standalone_compile_atomic_save() -> None:
"""Backport of pytorch/pytorch#162432 for torch < 2.10.0.
Patches CompiledArtifact.save() to use write_atomic for binary format,
preventing corrupt cache files when multiple processes compile
concurrently.
"""
from torch._inductor.codecache import write_atomic
from torch._inductor.standalone_compile import CompiledArtifact as cls
if getattr(cls.save, "_vllm_patched", False):
return
original_save = cls.save
def _save(
self: Any, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None:
if format != "binary":
return original_save(self, path=path, format=format)
from torch._dynamo.utils import dynamo_timed
from torch._inductor.codecache import torch_key
from torch.utils._appending_byte_serializer import BytesWriter
with dynamo_timed("CompiledArtifact.save"):
assert self._artifacts is not None
artifact_bytes, cache_info = self._artifacts
assert len(cache_info.aot_autograd_artifacts) == 1, cache_info
key = cache_info.aot_autograd_artifacts[0]
assert not os.path.isdir(path)
writer = BytesWriter()
writer.write_bytes(torch_key())
writer.write_str(key)
writer.write_bytes(artifact_bytes)
write_atomic(path, writer.to_bytes())
_save._vllm_patched = True # type: ignore[attr-defined]
cls.save = _save # type: ignore[assignment]
logger.debug("Patched %s.save for atomic writes (torch < 2.10)", cls.__name__)
class InductorStandaloneAdaptor(CompilerInterface):
"""
The adaptor for the Inductor compiler.
@@ -197,6 +238,8 @@ class InductorStandaloneAdaptor(CompilerInterface):
name = "inductor_standalone"
def __init__(self, save_format: Literal["binary", "unpacked"]) -> None:
if not is_torch_equal_or_newer("2.10.0"):
_patch_standalone_compile_atomic_save()
self.save_format = save_format
def compute_hash(self, vllm_config: VllmConfig) -> str:
@@ -224,7 +267,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
if compiler_config is not None:
current_config.update(compiler_config)
set_inductor_config(current_config, compile_range)
set_functorch_config()
# set_functorch_config()
if compile_range.is_single_size():
dynamic_shapes = "from_example_inputs"
@@ -325,6 +368,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
path=path, format=self.save_format
)
compilation_counter.num_compiled_artifacts_loaded += 1
from torch._inductor.compile_fx import graph_returns_tuple
returns_tuple = graph_returns_tuple(graph)
@@ -395,7 +439,7 @@ class InductorAdaptor(CompilerInterface):
current_config["fx_graph_remote_cache"] = False
set_inductor_config(current_config, compile_range)
set_functorch_config()
# set_functorch_config()
# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
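The _patch_standalone_compile_atomic_save backport above relies on write_atomic, which follows the usual write-to-a-temp-file-then-rename idiom: a concurrent reader sees either the old cache file or the complete new one, never a torn write. A minimal sketch of that idiom (not torch's actual implementation):

import os
import tempfile


def atomic_write_bytes(path: str, data: bytes) -> None:
    dirname = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=dirname)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
        # os.replace is atomic on POSIX: the destination is either the old
        # file or the complete new file, never a partially written one.
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
        raise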

View File

@@ -29,6 +29,8 @@ class CompilationCounter:
num_cache_entries_updated: int = 0
# The number of standalone_compile compiled artifacts saved
num_compiled_artifacts_saved: int = 0
# The number of standalone_compile compiled artifacts loaded from cache
num_compiled_artifacts_loaded: int = 0
# Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
stock_torch_compile_count: int = 0

View File

@@ -21,6 +21,7 @@ from vllm.model_executor.offloader.base import get_offloader
from vllm.platforms import current_platform
from vllm.utils.torch_utils import current_stream, weak_ref_tensors
from vllm.sequence import IntermediateTensors
logger = init_logger(__name__)
@@ -204,14 +205,14 @@ class CUDAGraphWrapper:
def unwrap(self) -> Callable[..., Any]:
# in case we need to access the original runnable.
return self.runnable
def weak_ref_tensors_with_intermediate(self, output):
if isinstance(output, IntermediateTensors):
intermediate_states = IntermediateTensors(
tensors={key: weak_ref_tensors(value) for key, value in output.tensors.items()})
return intermediate_states
return weak_ref_tensors(output)
def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
@@ -298,12 +299,10 @@ class CUDAGraphWrapper:
# the last graph in piecewise cudagraph mode, because
# the output of the last graph will not be used by
# any other cuda graph.
# output = weak_ref_tensors(output)
output = self.weak_ref_tensors_with_intermediate(output)
# here we always use weak ref for the output
# to save memory
# entry.output = weak_ref_tensors(output)
entry.output = self.weak_ref_tensors_with_intermediate(output)
entry.cudagraph = cudagraph
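A hedged usage sketch of the new weak_ref_tensors_with_intermediate helper: IntermediateTensors, the pipeline-parallel hand-off container, is not a plain tensor, so the wrapper rebuilds it with weak-referenced members instead of handing the container to weak_ref_tensors directly. The tensor names and the pre-built wrapper below are assumptions:

import torch
from vllm.sequence import IntermediateTensors

# wrapper: an already-constructed CUDAGraphWrapper instance (construction omitted)
out = IntermediateTensors(
    tensors={"hidden_states": torch.zeros(4, 8), "residual": torch.zeros(4, 8)}
)
wrapped = wrapper.weak_ref_tensors_with_intermediate(out)
assert isinstance(wrapped, IntermediateTensors)
# A plain tensor output would be passed straight to weak_ref_tensors instead.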

View File

@@ -53,7 +53,7 @@ class GEMMReduceScatterPattern(BasePattern):
gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
mul,
mm_weight,
"avg",
"sum",
scatter_dim=0,
group_name=self.tp.device_group.group_name,
)
@@ -150,7 +150,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
mat2,
scale_a,
scale_b,
"avg",
"sum",
scatter_dim, # orig_scatter_dim
scatter_dim, # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name,
@@ -285,7 +285,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
mat2,
scale_a,
scale_b,
"avg",
"sum",
scatter_dim, # orig_scatter_dim
scatter_dim, # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name,
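For context on the "avg" -> "sum" switch above: with "sum" the reduce-scatter shard on each rank is the plain element-wise sum of the per-rank partials, while "avg" additionally divides by the group size. A hedged sketch of the unfused equivalent, with made-up helper names:

import torch
import torch.distributed as dist


def unfused_matmul_reduce_scatter_sum(a: torch.Tensor, b: torch.Tensor, group) -> torch.Tensor:
    partial = a @ b  # each rank's partial matmul result
    world = dist.get_world_size(group)
    out = torch.empty(partial.shape[0] // world, *partial.shape[1:],
                      device=partial.device, dtype=partial.dtype)
    # scatter along dim 0 with a SUM reduction, matching scatter_dim=0 above
    dist.reduce_scatter_tensor(out, partial, op=dist.ReduceOp.SUM, group=group)
    return out  # the "avg" variant would additionally do: out / world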

View File

@@ -5,7 +5,6 @@ import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
from vllm._aiter_ops import rocm_aiter_ops
@@ -15,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
ScaleDesc,
kFp8Dynamic128Sym,
)
from vllm.platforms import current_platform
@@ -312,7 +312,9 @@ class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
logger.debug(
"%s Replaced %s patterns", self.__class__.__name__, self.matched_count
)
def uuid(self) -> str:
fusion_patterns = [
@@ -332,9 +334,11 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()
def __init__(self, quant_op: OpOverload) -> None:
def __init__(self) -> None:
self.silu_and_mul_matcher = MatcherSiluAndMul()
self.quant_op = quant_op
self.quant_matcher = MatcherQuantFP8(
quant_key=kFp8Dynamic128Sym, match_rocm_aiter=True
)
def get_inputs(self) -> list[torch.Tensor]:
return [
@@ -346,7 +350,7 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
at1 = self.silu_and_mul_matcher(input)
at2 = self.quant_op(at1, 128)
at2 = self.quant_matcher(at1)
return at2[0], at2[1]
def replacement(
@@ -370,11 +374,6 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op()
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
@@ -383,8 +382,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
)
for quant_op in self.QUANT_OPS:
AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
AiterSiluMulFp8GroupQuantPattern().register(self.patterns)
self.dump_patterns(config, self.patterns)
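The pass above now registers a single AiterSiluMulFp8GroupQuantPattern built on MatcherQuantFP8 instead of looping over QUANT_OPS. For readers unfamiliar with the underlying machinery, inductor pattern registration generally looks like the following generic sketch (a toy relu pattern, not the vLLM quant pattern):

import torch
from torch._inductor.pattern_matcher import (
    PatternMatcherPass,
    fwd_only,
    register_replacement,
)

patterns = PatternMatcherPass(pass_name="toy_fusion_pass")


def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + y


def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.clamp(x, min=0) + y


register_replacement(
    pattern,
    replacement,
    [torch.empty(8, 8), torch.empty(8, 8)],  # example inputs used to trace the pattern
    fwd_only,
    patterns,
)
# A pass then calls patterns.apply(graph) on the fx.Graph, as __call__ does above.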

View File

@@ -18,7 +18,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..utility.noop_elimination import NoOpEliminationPass
@@ -215,9 +214,6 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
)
FP8_DTYPE = current_platform.fp8_dtype()
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(
self,

View File

@@ -37,6 +37,14 @@ class FixFunctionalizationPass(VllmInductorPass):
self.nodes_to_remove: list[torch.fx.Node] = []
count = 0
rope_targets = [torch.ops._C.rotary_embedding.default]
if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
rope_targets.append(
torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default
)
for node in graph.nodes:
if not is_func(node, auto_functionalized):
continue # Avoid deep if-elif nesting
@@ -44,7 +52,7 @@ class FixFunctionalizationPass(VllmInductorPass):
kwargs = node.kwargs
at_target = node.args[0]
if at_target == torch.ops._C.rotary_embedding.default:
if at_target in rope_targets:
query = kwargs["query"]
key = kwargs["key"]
getitem_nodes = self.getitem_users(node)

View File

@@ -298,18 +298,18 @@ class PiecewiseBackend:
else list(args)
)
with (
torch._functorch.config.patch("bundled_autograd_cache", True),
):
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph,
args_list,
self.vllm_backend.inductor_config,
self.compilation_config,
compile_range=range_entry.compile_range,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
)
# with (
# torch._functorch.config.patch("bundled_autograd_cache", True),
# ):
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph,
args_list,
self.vllm_backend.inductor_config,
self.compilation_config,
compile_range=range_entry.compile_range,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
)
range_entry.compiled = True
self.to_be_compiled_ranges.remove(range_entry.compile_range)

View File

@@ -319,3 +319,52 @@ class TorchCompileWithNoGuardsWrapper:
yield
finally:
self.__class__.forward.__code__ = original
def reset_compile_wrapper(model: torch.nn.Module) -> None:
"""
Clean up compiled model and captured CUDA graphs for elastic EP.
"""
if not isinstance(model, TorchCompileWithNoGuardsWrapper) and hasattr(
model, "model"
):
model = model.model
if not isinstance(model, TorchCompileWithNoGuardsWrapper):
return
# model.do_not_compile is set by the @support_torch_compile decorator
if hasattr(model, "do_not_compile") and model.do_not_compile:
return
from vllm.compilation.counter import compilation_counter
# reset the compilation counter
compilation_counter.num_models_seen = 0
compilation_counter.num_graphs_seen = 0
compilation_counter.num_piecewise_graphs_seen = 0
compilation_counter.num_piecewise_capturable_graphs_seen = 0
compilation_counter.num_backend_compilations = 0
compilation_counter.num_gpu_runner_capture_triggers = 0
compilation_counter.num_cudagraph_captured = 0
compilation_counter.num_inductor_compiles = 0
compilation_counter.num_eager_compiles = 0
compilation_counter.num_cache_entries_updated = 0
compilation_counter.num_compiled_artifacts_saved = 0
compilation_counter.stock_torch_compile_count = 0
# Clear the AOT compiled function so the model is forced to
# recompile on the next call. Without this, decorators.py
# __call__ uses the stale aot_compiled_fn whose torchinductor
# kernels have old parameters (expert_map size for example)
# baked in as compile-time constants.
if hasattr(model, "aot_compiled_fn"):
model.aot_compiled_fn = None
if hasattr(model, "was_aot_compile_fn_loaded_from_disk"):
model.was_aot_compile_fn_loaded_from_disk = False
# Reset the cache_dir so VllmBackend recomputes the hash
# (data_parallel_size changed, so the config hash differs).
compilation_config = model.vllm_config.compilation_config
compilation_config.cache_dir = ""
compilation_config.local_cache_dir = ""
model.__class__.forward.__code__ = model.original_code_object()
TorchCompileWithNoGuardsWrapper.__init__(model)
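A hedged sketch of how an elastic-EP resize path might invoke reset_compile_wrapper; everything here except reset_compile_wrapper itself is an assumption:

def on_elastic_ep_resize(model_runner) -> None:
    # model_runner.model is the loaded torch.nn.Module; reset_compile_wrapper
    # unwraps one .model level itself if the outer object is not the wrapper.
    reset_compile_wrapper(model_runner.model)
    # The next forward pass then recomputes the compilation cache dir hash,
    # recompiles via torch.compile, and re-captures CUDA graphs.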