Add minimal vLLM 0.16.1 build repo for BI-V150

This commit is contained in:
2026-04-18 10:56:22 +08:00
commit d69657327e
1895 changed files with 615301 additions and 0 deletions

View File

@@ -0,0 +1,301 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import operator
from collections.abc import Iterable
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from vllm.logger import init_logger
from vllm.platforms import current_platform
from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class FixFunctionalizationPass(VllmInductorPass):
"""
This pass defunctionalizes certain nodes to avoid redundant tensor copies.
After this pass, DCE (dead-code elimination) should never be run,
as de-functionalized nodes may appear as dead code.
To add new nodes to defunctionalize, add to the if-elif chain in __call__.
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph) -> None:
# XPU does not support auto-functionalization yet.
# Will enable this when switch to vllm-xpu-kernels.
if current_platform.is_xpu():
logger.debug(
"XPU platform does not support fix functionalizationpass currently."
)
return
self.nodes_to_remove: list[torch.fx.Node] = []
count = 0
for node in graph.nodes:
if not is_func(node, auto_functionalized):
continue # Avoid deep if-elif nesting
kwargs = node.kwargs
at_target = node.args[0]
if at_target == torch.ops._C.rotary_embedding.default:
query = kwargs["query"]
key = kwargs["key"]
getitem_nodes = self.getitem_users(node)
if (
is_func(query, operator.getitem)
and is_func(key, operator.getitem)
and query.args[0] == key.args[0]
and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
and all(
is_func(user, torch.ops.aten.slice_scatter.default)
for getitem_node in getitem_nodes.values()
for user in getitem_node.users
)
):
# Pattern where query and key are slices of an mm_node.
# While functionalized, results at [1] and [2] are scattered
# back into mm_node. So after de-functionalization, we can
# just use mm_node directly.
mm_node = query.args[0].args[0]
for user in getitem_nodes.values():
for user_of_getitem in user.users:
if is_func(
user_of_getitem, torch.ops.aten.slice_scatter.default
):
user_of_getitem.replace_all_uses_with(mm_node)
self._remove(user_of_getitem)
self._remove(user)
self.insert_defunctionalized(graph, node)
self._remove(node)
else:
# Directly replace the auto_functionalize(rotary_embedding)
# with the inplace rotary_embedding. In theory, we shouldn't
# do this blindly, but in practice in vLLM it's ok. The best
# solution is to use auto_functionalization_v2 and then use
# inductor's builtin defunctionalization (reinplacing) pass.
mutated_args = {1: "query", 2: "key"}
self.defunctionalize(graph, node, mutated_args)
# rms_norm replacements avoid the most copies for LLaMa.
elif at_target == torch.ops._C.fused_add_rms_norm.default:
mutated_args = {1: "input", 2: "residual"}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
mutated_args = {1: "result", 2: "residual"}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
mutated_args = {1: "result", 2: "scale", 3: "residual"}
self.defunctionalize(graph, node, mutated_args)
elif at_target in [
torch.ops._C.rms_norm.default,
torch.ops._C.rms_norm_static_fp8_quant.default,
]:
mutated_args = {1: "result"}
self.defunctionalize(graph, node, mutated_args)
elif (
hasattr(torch.ops.vllm, "flashinfer_trtllm_fused_allreduce_norm")
and at_target
== torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
):
mutated_args = {
1: "allreduce_in",
2: "residual",
3: "norm_out",
4: "quant_out",
5: "scale_out",
}
self.defunctionalize(graph, node, mutated_args)
# For some reason we need to specify the args for both
# silu_and_mul and silu_and_mul_quant. The kwargs
# pathway gets the wrong answer.
elif at_target == torch.ops._C.silu_and_mul.default:
mutated_args = {1: "result"}
self.defunctionalize(
graph, node, mutated_args, args=("result", "input")
)
elif at_target == torch.ops._C.silu_and_mul_quant.default:
mutated_args = {1: "result"}
self.defunctionalize(
graph, node, mutated_args, args=("result", "input", "scale")
)
elif (
hasattr(torch.ops._C, "silu_and_mul_nvfp4_quant")
and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default
):
mutated_args = {1: "result", 2: "result_block_scale"}
self.defunctionalize(
graph,
node,
mutated_args,
args=(
"result",
"result_block_scale",
"input",
"input_global_scale",
),
)
# Defunctionalize fused_qk_norm_rope to remove higher-order wrapper.
elif at_target == torch.ops._C.fused_qk_norm_rope.default:
mutated_args = {1: "qkv"}
args = (
"qkv",
"num_heads_q",
"num_heads_k",
"num_heads_v",
"head_dim",
"eps",
"q_weight",
"k_weight",
"cos_sin_cache",
"is_neox",
"position_ids",
)
self.defunctionalize(graph, node, mutated_args=mutated_args, args=args)
elif (
hasattr(torch.ops.vllm, "fused_rope_and_unified_kv_cache_update")
and at_target
== torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default
):
mutated_args = {
1: "query",
2: "key",
}
self.defunctionalize(graph, node, mutated_args=mutated_args)
# only used for test_functionalization::TestFunctionWithMutatedArgsAndReturn
elif (
hasattr(torch.ops.vllm, "function_with_mutated_args_and_return")
and at_target
== torch.ops.vllm.function_with_mutated_args_and_return.default
):
mutated_args = {1: "x"}
self.defunctionalize(graph, node, mutated_args=mutated_args)
else:
continue # skip the count
count += 1
self.dump_graph(graph, "before_cleanup")
# Remove the nodes all at once
count_removed = len(self.nodes_to_remove)
for node in self.nodes_to_remove:
graph.erase_node(node)
logger.debug(
"De-functionalized %s nodes, removed %s nodes", count, count_removed
)
self.nodes_to_remove.clear()
def _remove(self, node_or_nodes: torch.fx.Node | Iterable[torch.fx.Node]) -> None:
"""
Stage a node (or nodes) for removal at the end of the pass.
"""
if isinstance(node_or_nodes, torch.fx.Node):
self.nodes_to_remove.append(node_or_nodes)
else:
self.nodes_to_remove.extend(node_or_nodes)
def defunctionalize(
self,
graph: torch.fx.Graph,
node: torch.fx.Node,
mutated_args: dict[int, torch.fx.Node | str],
args: tuple[torch.fx.Node | str, ...] | None = None,
) -> None:
"""
De-functionalize a node by replacing it with a call to the original.
It also replaces the getitem users with the mutated arguments.
See replace_users_with_mutated_args and insert_defunctionalized.
"""
self.replace_users_with_mutated_args(node, mutated_args)
self.insert_defunctionalized(graph, node, args=args)
self._remove(node)
def replace_users_with_mutated_args(
self, node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node | str]
) -> None:
"""
Replace mutated getitem users of the auto-functionalized node with the
mutated arguments.
:param node: The auto-functionalized node
:param mutated_args: The mutated arguments, indexed by getitem index.
If the value of an arg is a string, `node.kwargs[arg]` is used.
"""
for idx, user in self.getitem_users(node).items():
# Some functionalized nodes may return both a result at getitem[0]
# as well as mutated args at getitem[1:...]
if idx == 0:
assert idx not in mutated_args, (
f"result at getitem[0] should not be in mutated_args for {node}"
)
continue
arg = mutated_args[idx]
arg = node.kwargs[arg] if isinstance(arg, str) else arg
user.replace_all_uses_with(arg)
self._remove(user)
def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
"""
Returns the operator.getitem users of the auto-functionalized node,
indexed by the index they are getting.
"""
users = {}
for user in node.users:
if is_func(user, operator.getitem):
idx = user.args[1]
users[idx] = user
return users
def insert_defunctionalized(
self,
graph: torch.fx.Graph,
node: torch.fx.Node,
args: tuple[torch.fx.Node | str, ...] | None = None,
) -> None:
"""
Insert a new defunctionalized node into the graph before node.
If one of the kwargs is 'out', provide args directly,
as node.kwargs cannot be used.
See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351
:param graph: Graph to insert the defunctionalized node into
:param node: The auto-functionalized node to defunctionalize
:param args: If we cannot use kwargs, specify args directly.
If an arg is a string, `node.kwargs[arg]` is used.
""" # noqa: E501
assert is_func(node, auto_functionalized), (
f"node must be auto-functionalized, is {node} instead"
)
# Create a new call to the original function
with graph.inserting_before(node):
function = node.args[0]
if args is None:
fn_node = graph.call_function(function, kwargs=node.kwargs)
else:
# Args passed as strings refer to items in node.kwargs
args = tuple(
node.kwargs[arg] if isinstance(arg, str) else arg for arg in args
)
fn_node = graph.call_function(function, args=args)
# If the function returns a value as well as mutating args inplace,
# the functionalized node will have a getitem[0] user that holds this value
# Replace getitem[0] user of the auto-functionalized node
# with the new defunctionalized node directly if it exists
users = self.getitem_users(node)
if 0 in users:
user = users[0]
user.replace_all_uses_with(fn_node)
self._remove(user)

View File

@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch.fx
from torch import SymInt
from torch.fx.experimental.symbolic_shapes import statically_known_true
from vllm.logger import init_logger
from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class NoOpEliminationPass(VllmInductorPass):
"""
This is an inductor pass that removes redundant reshape/slice operations.
It is required for RMSNorm-quant fusion to work properly.
That's because apply_fp8_linear adds a reshape, which is redundant
in the 2D-case. Additionally, torch internal no-op elimination pass does
not handle certain slice variants.
Cases handled:
1. A chain of reshapes is equivalent to the last reshape called on the
base tensor (input of the first reshape).
2. A reshape that produces the shape of the input is redundant
3. A slice that produces the shape of the input is redundant
Example graph 1:
mul_1: "f16[s0, 4096]" = ...
view_1: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32])
view_2: "f16[s0, 4096]" = torch.reshape(view_2, [-1, 4096])
view_3: "f16[s0, 128, 32]" = torch.reshape(view_3, [-1, 128, 32])
Can be replaced with:
mul_1: "f16[s0, 4096]" = ...
view_3: "f16[s0, 128, 32]" = ...
Example graph 2:
getitem_1: "f16[s0, 4096]" = ...
view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
out: "f8e4m3fn[s0, 4096]" = at[1]
Can be replaced with:
getitem_1: "f16[s0, 4096]" = ...
at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
out: "f8e4m3fn[s0, 4096]" = at[1]
Example graph 3:
arg0: "s0" = SymInt(s0)
scaled_mm: "f16[s0, 4096]" = ...
slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0)
at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...)
out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0)
Can be replaced with:
arg0: "s0" = SymInt(s0)
scaled_mm: "f16[s0, 4096]" = ...
at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
out: "f16[s0, 4096]" = at[1]
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph) -> None:
count = 0
# Remove no-op reshapes/views:
for node in graph.nodes:
if is_func(node, torch.ops.aten.reshape.default):
# Case 1: rewrite reshape chains to reshapes on the base tensor
input = node.args[0]
# If the input is a reshape, rebind to that node
if is_func(input, torch.ops.aten.reshape.default):
# The new input is guaranteed not to be a reshape,
# because we process nodes in order
node.update_arg(0, input.args[0])
if len(input.users) == 0:
graph.erase_node(input)
count += 1
# remove reshape/slice if it produces the original shape
if is_func(node, torch.ops.aten.reshape.default) or is_func(
node, torch.ops.aten.slice.Tensor
):
input = node.args[0]
input_shape = input.meta["val"].shape
output_shape = node.meta["val"].shape
if self.all_dims_equivalent(input_shape, output_shape):
node.replace_all_uses_with(input)
graph.erase_node(node)
count += 1
elif is_func(node, torch.ops.aten.slice_scatter.default):
base, view, dim_index, start, end = node.args[:5]
base_shape = base.meta["val"].shape
view_shape = view.meta["val"].shape
if self.all_dims_equivalent(base_shape, view_shape):
node.replace_all_uses_with(view)
graph.erase_node(node)
count += 1
logger.debug("Removed %s no-op reshapes and slices", count)
# ---------------------- Shape comparison helpers ----------------------
def dims_equivalent(self, dim: int | SymInt, i_dim: int | SymInt) -> bool:
"""
This function checks if two dimensions are equivalent.
:param dim: The dimension arg to reshape/slice
:param i_dim: The corresponding dimension in the input tensor
:return: Are the dimensions equivalent?
There are two cases in which the dimensions are equivalent:
1. The dimensions are equal (both integers)
2. The dimensions both correspond to the same SymInt
"""
# Case 1
return statically_known_true(dim == i_dim) # type: ignore[no-any-return]
def all_dims_equivalent(
self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt]
) -> bool:
dims_ = list(dims)
i_dims_ = list(i_dims)
if len(dims_) != len(i_dims_):
# Different ranks can't be equivalent
return False
return all(self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))

View File

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from torch import fx
from ..vllm_inductor_pass import VllmInductorPass
class PostCleanupPass(VllmInductorPass):
"""
This pass performs cleanup after custom passes.
It topologically sorts the graph and removes unused nodes.
This is needed because the pattern matcher does not guarantee producing
a topologically sorted graph, and there may be unused nodes left around.
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
from torch._inductor.pattern_matcher import stable_topological_sort
stable_topological_sort(graph)
graph.eliminate_dead_code()

View File

@@ -0,0 +1,138 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Replace ``slice_scatter`` and ``split_with_sizes`` nodes with a single
assignment if there are no users for the inplace tensor written to by
the slice_scatter call.
The inplace rotary_embedding custom op takes in mutable query and key inputs
that are split+getitem outputs of a single qkv tensor.
When functionalized, we fetch the rotated query and key from the functionalized op
using `getitem` calls. However, we also write to the qkv tensor inplace using a
`slice_scatter`, then split the inplace tensor to get the output tensors again.
Instead, if the inplace tensor has no subsequent users, we can just replace the
`slice_scatter` and `split_with_sizes` nodes with the `getitem` calls.
This is already done in fix_functionalization::FixFunctionalizationPass, but
writing a custom pass for it before defunctionalization allows matching against the
qkv split+rotary_embedding subpattern as part of e.g. the RoPE+KVCache fusion pass.
"""
import operator
import torch
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from vllm.logger import init_logger
from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class ScatterSplitReplacementPass(VllmInductorPass):
"""Replace getitem+slice_scatter+split nodes with a single getitem when
the inplace subtensor written to by the slice_scatter has no other users.
Here's an example graph with q_size = 512, kv_size = 64:
split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
q = operator.getitem(at, 1)
k = operator.getitem(at, 2)
torch.ops.aten.slice_scatter.default(qkv, q, [0, 512], -1)
torch.ops.aten.slice_scatter.default(qkv, k, [512, 512 + 64], -1)
split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
q = operator.getitem(split_with_sizes_2, 0)
k = operator.getitem(split_with_sizes_2, 1)
v = operator.getitem(split_with_sizes_2, 2)
After this pass, this sequence of nodes is replaced with:
split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
q = operator.getitem(at, 1)
k = operator.getitem(at, 2)
v = operator.getitem(split_with_sizes_1, 2)
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
count = 0
target_ops = [torch.ops._C.rotary_embedding.default]
if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
target_ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default)
for node in graph.nodes:
if not is_func(node, auto_functionalized):
continue
kwargs = node.kwargs
at_target = node.args[0]
if at_target in target_ops:
query = kwargs["query"]
key = kwargs["key"]
getitem_nodes = {}
for user in node.users:
if is_func(user, operator.getitem):
getitem_nodes[user.args[1]] = user
if (
is_func(query, operator.getitem)
and is_func(key, operator.getitem)
and query.args[0] == key.args[0]
and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
and all(
is_func(user, torch.ops.aten.slice_scatter.default)
for getitem_node in getitem_nodes.values()
for user in getitem_node.users
)
):
# Pattern where query and key are slices of a qkv tensor.
# While functionalized, results at [1] and [2] are scattered
# back into qkv, then split again to get query and key.
# If the inplace tensor has no other users, we can replace
# the slice_scatter+split nodes with the original results.
for user in getitem_nodes[1].users:
slice_scatter_1_node = user
if not is_func(
slice_scatter_1_node, torch.ops.aten.slice_scatter.default
):
continue
for user in getitem_nodes[2].users:
slice_scatter_2_node = user
if not is_func(
slice_scatter_2_node, torch.ops.aten.slice_scatter.default
):
continue
for user in slice_scatter_2_node.users:
split_node = user
if not is_func(split_node, torch.ops.aten.split_with_sizes.default):
continue
split_getitem_users = {}
for user in split_node.users:
if is_func(user, operator.getitem):
split_getitem_users[user.args[1]] = user
# Replace query node
split_getitem_users[0].replace_all_uses_with(getitem_nodes[1])
graph.erase_node(split_getitem_users[0])
# Replace key node
split_getitem_users[1].replace_all_uses_with(getitem_nodes[2])
graph.erase_node(split_getitem_users[1])
# Redirect value node to original qkv tensor
split_getitem_users[2].replace_input_with(split_node, query.args[0])
# Erase unused nodes
graph.erase_node(split_node)
graph.erase_node(slice_scatter_2_node)
graph.erase_node(slice_scatter_1_node)
count += 1
logger.debug("Eliminated %d slice_scatter+split nodes", count)

View File

@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Coalesce duplicate ``split_with_sizes`` nodes that operate on the same
input tensor with the same split sizes.
On certain hardware/dtype combinations (e.g. B200 + FP8) the Inductor
graph may contain multiple ``split_with_sizes`` calls on the same tensor
that CSE fails to merge. This pass detects and replaces the duplicates
so that downstream pattern-matching passes (e.g. QK-Norm+RoPE fusion)
see a single split node with all users attached.
See also:
- vLLM #33295 (original issue)
- PyTorch #174472 (upstream CSE gap)
"""
import operator
import torch
from torch import fx
from vllm.logger import init_logger
from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class SplitCoalescingPass(VllmInductorPass):
"""Replace duplicate ``split_with_sizes`` nodes with a single canonical
node when they share the same input tensor and split sizes."""
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
count = 0
# Map from input tensor node -> list of split nodes seen so far.
split_nodes: dict[fx.Node, list[fx.Node]] = {}
for node in graph.nodes:
if not is_func(node, torch.ops.aten.split_with_sizes.default):
continue
if not all(is_func(user, operator.getitem) for user in node.users):
continue
arg_node, split_sizes = node.args[:2]
if arg_node not in split_nodes:
split_nodes[arg_node] = [node]
continue
# Find existing node with same split_sizes
canonical = next(
(
n
for n in split_nodes[arg_node]
if list(n.args[1]) == list(split_sizes)
),
None,
)
if canonical is not None:
node.replace_all_uses_with(canonical)
graph.erase_node(node)
count += 1
else:
split_nodes[arg_node].append(node)
logger.debug("Coalesced %d duplicate split_with_sizes nodes", count)