update
This commit is contained in:
138
vllm/compilation/passes/utility/scatter_split_replace.py
Normal file
138
vllm/compilation/passes/utility/scatter_split_replace.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Replace ``slice_scatter`` and ``split_with_sizes`` nodes with a single
|
||||
assignment if there are no users for the inplace tensor written to by
|
||||
the slice_scatter call.
|
||||
|
||||
The inplace rotary_embedding custom op takes in mutable query and key inputs
|
||||
that are split+getitem outputs of a single qkv tensor.
|
||||
When functionalized, we fetch the rotated query and key from the functionalized op
|
||||
using `getitem` calls. However, we also write to the qkv tensor inplace using a
|
||||
`slice_scatter`, then split the inplace tensor to get the output tensors again.
|
||||
Instead, if the inplace tensor has no subsequent users, we can just replace the
|
||||
`slice_scatter` and `split_with_sizes` nodes with the `getitem` calls.
|
||||
|
||||
This is already done in fix_functionalization::FixFunctionalizationPass, but
|
||||
writing a custom pass for it before defunctionalization allows matching against the
|
||||
qkv split+rotary_embedding subpattern as part of e.g. the RoPE+KVCache fusion pass.
|
||||
"""
|
||||
|
||||
import operator
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from ..fx_utils import is_func
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ScatterSplitReplacementPass(VllmInductorPass):
    """Replace getitem+slice_scatter+split nodes with a single getitem when
    the inplace subtensor written to by the slice_scatter has no other users.

    Here's an example graph with q_size = 512, kv_size = 64:
    split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
    at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
    q = operator.getitem(at, 1)
    k = operator.getitem(at, 2)
    torch.ops.aten.slice_scatter.default(qkv, q, [0, 512], -1)
    torch.ops.aten.slice_scatter.default(qkv, k, [512, 512 + 64], -1)
    split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
    q = operator.getitem(split_with_sizes_2, 0)
    k = operator.getitem(split_with_sizes_2, 1)
    v = operator.getitem(split_with_sizes_2, 2)

    After this pass, this sequence of nodes is replaced with:
    split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
    at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
    q = operator.getitem(at, 1)
    k = operator.getitem(at, 2)
    v = operator.getitem(split_with_sizes_1, 2)
    """

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Scan ``graph`` for the scatter+split pattern around a
        functionalized rotary-embedding op and rewrite it in place.

        Args:
            graph: the FX graph to mutate. Nodes are replaced/erased
                directly; nothing is returned.
        """
        # Number of eliminated pattern instances, reported at the end.
        count = 0

        # Ops whose auto_functionalized wrapper produces the rotated
        # query/key outputs we want to forward to downstream users.
        target_ops = [torch.ops._C.rotary_embedding.default]
        # The ROCm AITER variant is only registered on some builds, so
        # probe for it before referencing the overload.
        if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
            target_ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default)

        for node in graph.nodes:
            # The pattern is anchored on auto_functionalized HOP nodes.
            if not is_func(node, auto_functionalized):
                continue

            kwargs = node.kwargs
            # args[0] of an auto_functionalized node is the wrapped op overload.
            at_target = node.args[0]

            if at_target in target_ops:
                query = kwargs["query"]
                key = kwargs["key"]
                # Map getitem index -> getitem node reading this op's outputs.
                # For rotary_embedding, index 1 is the rotated query and
                # index 2 is the rotated key (see class docstring example).
                getitem_nodes = {}
                for user in node.users:
                    if is_func(user, operator.getitem):
                        getitem_nodes[user.args[1]] = user

                # Match only when query and key are two outputs of the SAME
                # split_with_sizes of one qkv tensor, and every consumer of
                # the op's getitem outputs is a slice_scatter (i.e. the
                # results are only written back into qkv, nothing else).
                if (
                    is_func(query, operator.getitem)
                    and is_func(key, operator.getitem)
                    and query.args[0] == key.args[0]
                    and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
                    and all(
                        is_func(user, torch.ops.aten.slice_scatter.default)
                        for getitem_node in getitem_nodes.values()
                        for user in getitem_node.users
                    )
                ):
                    # Pattern where query and key are slices of a qkv tensor.
                    # While functionalized, results at [1] and [2] are scattered
                    # back into qkv, then split again to get query and key.
                    # If the inplace tensor has no other users, we can replace
                    # the slice_scatter+split nodes with the original results.

                    # NOTE(review): these loops keep the LAST user seen; the
                    # `continue` only skips non-matching users within the loop.
                    # The all(...) guard above guarantees every user of
                    # getitem_nodes[1]/[2] is a slice_scatter, so the final
                    # assignment is a valid match — presumably each has a
                    # single slice_scatter user; verify against the pattern.
                    for user in getitem_nodes[1].users:
                        slice_scatter_1_node = user
                        if not is_func(
                            slice_scatter_1_node, torch.ops.aten.slice_scatter.default
                        ):
                            continue

                    for user in getitem_nodes[2].users:
                        slice_scatter_2_node = user
                        if not is_func(
                            slice_scatter_2_node, torch.ops.aten.slice_scatter.default
                        ):
                            continue

                    # Find the split_with_sizes that re-reads the scattered
                    # qkv tensor (split_with_sizes_2 in the docstring example).
                    # Same last-user-wins shape as the loops above.
                    for user in slice_scatter_2_node.users:
                        split_node = user
                        if not is_func(split_node, torch.ops.aten.split_with_sizes.default):
                            continue

                    # Map split-output index -> getitem node (0=q, 1=k, 2=v).
                    split_getitem_users = {}
                    for user in split_node.users:
                        if is_func(user, operator.getitem):
                            split_getitem_users[user.args[1]] = user

                    # Replace query node
                    split_getitem_users[0].replace_all_uses_with(getitem_nodes[1])
                    graph.erase_node(split_getitem_users[0])
                    # Replace key node
                    split_getitem_users[1].replace_all_uses_with(getitem_nodes[2])
                    graph.erase_node(split_getitem_users[1])
                    # Redirect value node to original qkv tensor
                    # (query.args[0] is the first split, whose [2] output is v).
                    split_getitem_users[2].replace_input_with(split_node, query.args[0])

                    # Erase unused nodes. Order matters: erase consumers
                    # (the split) before their producers (the scatters).
                    graph.erase_node(split_node)
                    graph.erase_node(slice_scatter_2_node)
                    graph.erase_node(slice_scatter_1_node)

                    count += 1

        logger.debug("Eliminated %d slice_scatter+split nodes", count)
|
||||
Reference in New Issue
Block a user