139 lines
6.0 KiB
Python
139 lines
6.0 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""
|
|
Replace ``slice_scatter`` and ``split_with_sizes`` nodes with a single
|
|
assignment if there are no users for the inplace tensor written to by
|
|
the slice_scatter call.
|
|
|
|
The inplace rotary_embedding custom op takes in mutable query and key inputs
|
|
that are split+getitem outputs of a single qkv tensor.
|
|
When functionalized, we fetch the rotated query and key from the functionalized op
|
|
using `getitem` calls. However, we also write to the qkv tensor inplace using a
|
|
`slice_scatter`, then split the inplace tensor to get the output tensors again.
|
|
Instead, if the inplace tensor has no subsequent users, we can just replace the
|
|
`slice_scatter` and `split_with_sizes` nodes with the `getitem` calls.
|
|
|
|
This is already done in fix_functionalization::FixFunctionalizationPass, but
|
|
writing a custom pass for it before defunctionalization allows matching against the
|
|
qkv split+rotary_embedding subpattern as part of e.g. the RoPE+KVCache fusion pass.
|
|
"""
|
|
|
|
import operator
|
|
|
|
import torch
|
|
from torch import fx
|
|
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
|
|
|
from vllm.logger import init_logger
|
|
|
|
from ..fx_utils import is_func
|
|
from ..vllm_inductor_pass import VllmInductorPass
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
class ScatterSplitReplacementPass(VllmInductorPass):
    """Replace getitem+slice_scatter+split nodes with a single getitem when
    the inplace subtensor written to by the slice_scatter has no other users.

    Here's an example graph with q_size = 512, kv_size = 64:
    split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
    at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
    q = operator.getitem(at, 1)
    k = operator.getitem(at, 2)
    torch.ops.aten.slice_scatter.default(qkv, q, [0, 512], -1)
    torch.ops.aten.slice_scatter.default(qkv, k, [512, 512 + 64], -1)
    split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
    q = operator.getitem(split_with_sizes_2, 0)
    k = operator.getitem(split_with_sizes_2, 1)
    v = operator.getitem(split_with_sizes_2, 2)

    After this pass, this sequence of nodes is replaced with:
    split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
    at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
    q = operator.getitem(at, 1)
    k = operator.getitem(at, 2)
    v = operator.getitem(split_with_sizes_1, 2)
    """

    @staticmethod
    def _find_user(node: fx.Node, target) -> "fx.Node | None":
        """Return the first user of ``node`` that is a call to ``target``.

        Returns ``None`` when no such user exists, so callers can skip
        non-matching graphs instead of operating on an arbitrary node.
        (The previous implementation looped over ``node.users`` assigning
        each user before checking it, with a trailing ``continue`` that was
        a no-op — it silently selected the *last* user whether or not it
        matched the filter.)
        """
        for user in node.users:
            if is_func(user, target):
                return user
        return None

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Eliminate redundant slice_scatter+split chains after rotary ops.

        Mutates ``graph`` in place; logs how many chains were eliminated.
        """
        count = 0

        # Rotary-embedding custom ops whose functionalized form produces the
        # pattern this pass targets. The ROCm AITER variant only exists on
        # builds that register it, hence the hasattr guard.
        target_ops = [torch.ops._C.rotary_embedding.default]
        if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
            target_ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default)

        for node in graph.nodes:
            if not is_func(node, auto_functionalized):
                continue
            # args[0] of auto_functionalized is the wrapped op overload.
            if node.args[0] not in target_ops:
                continue

            kwargs = node.kwargs
            query = kwargs["query"]
            key = kwargs["key"]

            # Map output index -> getitem node fetching that output of the
            # functionalized op ([1] is the rotated query, [2] the rotated key).
            getitem_nodes = {}
            for user in node.users:
                if is_func(user, operator.getitem):
                    getitem_nodes[user.args[1]] = user

            # Match only the pattern where query and key are slices (getitems
            # of one split) of the same qkv tensor, both rotated outputs are
            # fetched, and every fetched output is consumed solely by
            # slice_scatter writes back into qkv. The explicit 1/2 membership
            # checks prevent the KeyErrors the old code could raise when only
            # one output was fetched.
            if not (
                is_func(query, operator.getitem)
                and is_func(key, operator.getitem)
                and query.args[0] == key.args[0]
                and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
                and 1 in getitem_nodes
                and 2 in getitem_nodes
                and all(
                    is_func(user, torch.ops.aten.slice_scatter.default)
                    for getitem_node in getitem_nodes.values()
                    for user in getitem_node.users
                )
            ):
                continue

            # Pattern where query and key are slices of a qkv tensor.
            # While functionalized, results at [1] and [2] are scattered
            # back into qkv, then split again to get query and key.
            # If the inplace tensor has no other users, we can replace
            # the slice_scatter+split nodes with the original results.
            slice_scatter_1_node = self._find_user(
                getitem_nodes[1], torch.ops.aten.slice_scatter.default
            )
            slice_scatter_2_node = self._find_user(
                getitem_nodes[2], torch.ops.aten.slice_scatter.default
            )
            if slice_scatter_1_node is None or slice_scatter_2_node is None:
                continue
            # The second scatter feeds the re-split of qkv.
            split_node = self._find_user(
                slice_scatter_2_node, torch.ops.aten.split_with_sizes.default
            )
            if split_node is None:
                continue

            # Map output index -> getitem node for the second split's outputs.
            split_getitem_users = {}
            for user in split_node.users:
                if is_func(user, operator.getitem):
                    split_getitem_users[user.args[1]] = user
            # All three outputs (q, k, v) must be fetched, otherwise the
            # rewiring below would KeyError on a partially-used split.
            if any(i not in split_getitem_users for i in (0, 1, 2)):
                continue

            # Replace query node with the rotated query from the op itself.
            split_getitem_users[0].replace_all_uses_with(getitem_nodes[1])
            graph.erase_node(split_getitem_users[0])
            # Replace key node with the rotated key from the op itself.
            split_getitem_users[1].replace_all_uses_with(getitem_nodes[2])
            graph.erase_node(split_getitem_users[1])
            # v is untouched by the rotation: redirect its getitem to read
            # from the original qkv split instead of the second one.
            split_getitem_users[2].replace_input_with(split_node, query.args[0])

            # Erase the now-unused chain, last-to-first so each node is
            # user-free when erased.
            graph.erase_node(split_node)
            graph.erase_node(slice_scatter_2_node)
            graph.erase_node(slice_scatter_1_node)

            count += 1

        logger.debug("Eliminated %d slice_scatter+split nodes", count)