71 lines
2.2 KiB
Python
71 lines
2.2 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""
|
|
Coalesce duplicate ``split_with_sizes`` nodes that operate on the same
|
|
input tensor with the same split sizes.
|
|
|
|
On certain hardware/dtype combinations (e.g. B200 + FP8) the Inductor
|
|
graph may contain multiple ``split_with_sizes`` calls on the same tensor
|
|
that CSE fails to merge. This pass detects and replaces the duplicates
|
|
so that downstream pattern-matching passes (e.g. QK-Norm+RoPE fusion)
|
|
see a single split node with all users attached.
|
|
|
|
See also:
|
|
- vLLM #33295 (original issue)
|
|
- PyTorch #174472 (upstream CSE gap)
|
|
"""
|
|
|
|
import operator
|
|
|
|
import torch
|
|
from torch import fx
|
|
|
|
from vllm.logger import init_logger
|
|
|
|
from ..fx_utils import is_func
|
|
from ..vllm_inductor_pass import VllmInductorPass
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
class SplitCoalescingPass(VllmInductorPass):
    """Replace duplicate ``split_with_sizes`` nodes with a single canonical
    node when they share the same input tensor, split sizes, and split dim.

    The first matching split encountered in graph order is kept as the
    canonical node; each later duplicate has all of its users redirected to
    the canonical node and is then erased from the graph.
    """

    @staticmethod
    def _split_dim(node: fx.Node, arg_node: object) -> object:
        """Return the (normalized when possible) ``dim`` of a split node.

        ``aten.split_with_sizes(self, split_sizes, dim=0)`` may receive
        ``dim`` positionally, as a keyword, or not at all (default 0).
        A negative ``dim`` is normalized against the input's rank when the
        input node's ``meta["val"]`` is a tensor; otherwise the raw value is
        returned, which can at worst cause a missed merge, never a wrong one.

        Args:
            node: A ``split_with_sizes`` call node.
            arg_node: The tensor input of ``node`` (its first argument).

        Returns:
            The split dimension used for equality comparison.
        """
        if len(node.args) > 2:
            dim = node.args[2]
        else:
            dim = node.kwargs.get("dim", 0)
        if isinstance(dim, int) and dim < 0 and isinstance(arg_node, fx.Node):
            val = arg_node.meta.get("val")
            if isinstance(val, torch.Tensor):
                dim += val.dim()
        return dim

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Coalesce duplicate ``split_with_sizes`` nodes in ``graph`` in place.

        Args:
            graph: The FX graph to rewrite. Duplicate split nodes are erased
                after their users are moved onto the canonical split.
        """
        count = 0

        # Map from input tensor node -> split nodes seen so far on that input.
        split_nodes: dict[fx.Node, list[fx.Node]] = {}

        for node in graph.nodes:
            if not is_func(node, torch.ops.aten.split_with_sizes.default):
                continue
            # Only coalesce splits whose outputs are consumed purely through
            # getitem; any other consumer of the result tuple is left alone.
            if not all(is_func(user, operator.getitem) for user in node.users):
                continue
            # Expect self and split_sizes positionally; skip anything else
            # rather than crash on the unpack below.
            if len(node.args) < 2:
                continue

            arg_node, split_sizes = node.args[:2]
            dim = self._split_dim(node, arg_node)
            candidates = split_nodes.setdefault(arg_node, [])

            # Two splits are interchangeable only if both the sizes AND the
            # split dimension match; equal sizes on different dims produce
            # entirely different views.
            canonical = next(
                (
                    n
                    for n in candidates
                    if list(n.args[1]) == list(split_sizes)
                    and self._split_dim(n, arg_node) == dim
                ),
                None,
            )
            if canonical is None:
                candidates.append(node)
                continue

            # Nodes are visited in graph order, so `canonical` precedes `node`
            # and therefore dominates every user being redirected to it.
            node.replace_all_uses_with(canonical)
            graph.erase_node(node)
            count += 1

        logger.debug("Coalesced %d duplicate split_with_sizes nodes", count)
|