### What this PR does / why we need it?
We introduced the npugraph_ex backend through the vllm's adaptor
dispatch mechanism to accelerate aclgraph. This solution is based on
torch.compile and uses torchair to optimize the fx.graph. The
performance gains are mainly obtained from the static kernel. We
conducted tests on Qwen3-30B and achieved over 5% performance
optimization.
### Does this PR introduce _any_ user-facing change?
Yes, we add a new switch named "enable_npugraph_ex" in additional_config;
it defaults to False.
We also add an example to show how to register a custom replacement pass.
### More information about this PR
This feature depends on the release of CANN and torch_npu in Q4.
We tested it on a package that has not been publicly released yet and
verified that the functionality works.
This feature is still experimental at the moment; setting the config
to true will directly raise an error.
Merging into the main branch initially involves some preliminary commits
to facilitate subsequent development and testing of the feature, as well
as to avoid submitting an excessively large PR at once.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: chencangtao <chencangtao@huawei.com>
Signed-off-by: ChenCangtao <50493711+ChenCangtao@users.noreply.github.com>
Co-authored-by: chencangtao <chencangtao@huawei.com>
Co-authored-by: panchao-hub <315134829@qq.com>
Co-authored-by: wbigat <wbigat@163.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
138 lines
5.2 KiB
Python
138 lines
5.2 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
import functools
|
|
from typing import Any, Callable, Optional
|
|
|
|
import torch
|
|
import torch.fx as fx
|
|
from torch._dynamo.backends.common import aot_autograd
|
|
from torch._inductor.compile_fx import (graph_returns_tuple,
|
|
make_graph_return_tuple)
|
|
from torch._inductor.decomposition import select_decomp_table
|
|
from torch.fx import GraphModule
|
|
from vllm.compilation.compiler_interface import CompilerInterface
|
|
|
|
from vllm_ascend.ascend_config import get_ascend_config
|
|
|
|
|
|
def compile_fx(graph: GraphModule, example_inputs: list,
|
|
inner_compile: Callable, decompositions: dict) -> Callable:
|
|
recursive_compile_fx = functools.partial(compile_fx,
|
|
inner_compile=inner_compile,
|
|
decompositions=decompositions)
|
|
|
|
if not graph_returns_tuple(graph):
|
|
return make_graph_return_tuple(graph, example_inputs,
|
|
recursive_compile_fx)
|
|
return aot_autograd(fw_compiler=inner_compile)(graph, example_inputs)
|
|
|
|
|
|
def fusion_pass_compile(
    graph: fx.GraphModule,
    example_inputs: list[Any],
    compiler_config: dict[str, Any],
    runtime_shape: Optional[int] = None,
    key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
    """Compile ``graph`` by running the configured graph-fusion passes.

    Returns a ``(compiled_callable, None)`` pair; the second element is the
    serialized-artifact slot of the compiler interface and is unused here.
    """

    def _run_fusion_passes(fx_graph, _example_inputs):
        # The pass manager is injected by the caller through
        # ``compiler_config``; it rewrites the graph for the given
        # runtime shape and the result is used as the compiled artifact.
        manager = compiler_config["graph_fusion_manager"]
        return manager(fx_graph, runtime_shape)

    compiled = compile_fx(
        graph=graph,
        example_inputs=example_inputs,
        inner_compile=_run_fusion_passes,
        decompositions=select_decomp_table(),
    )
    return compiled, None
|
|
|
|
|
|
def npugraph_ex_compile(
    graph: fx.GraphModule,
    example_inputs: list[Any],
    compiler_config: dict[str, Any],
    runtime_shape: Optional[int] = None,
    key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
    """Compile ``graph`` through torchair's npugraph_ex backend.

    Returns a ``(compiled_callable, None)`` pair; the second element is the
    serialized-artifact slot of the compiler interface and is unused here.
    """
    # In FULL_DECODE_ONLY mode vllm's piecewise-compilation slicing can
    # produce a single-output fx graph whose output is not wrapped in a
    # tuple, while torch.compile mandates tuple-shaped outputs. Rewrite
    # the output node here so the graph satisfies that check.
    fx_graph = graph.graph
    if not graph_returns_tuple(graph):
        out_node = fx_graph.output_node()
        with fx_graph.inserting_before(out_node):
            single_result = out_node.args[0]
            wrapped = fx_graph.create_node("call_function",
                                           tuple,
                                           args=([single_result], ))
            out_node.args = (wrapped, )
        fx_graph.recompile()

    import torchair

    # TODO: use a better way to lazy register replacement, instead of import one by one
    # As an example, we directly import here to register replacement.
    import vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant  # noqa

    torch.npu.set_compile_mode(jit_compile=False)

    config = torchair.CompilerConfig()
    # use aclgraph mode, avoid the transformation from fx graph to Ascend IR.
    config.mode = "reduce-overhead"
    # execute FX graph in eager mode before graph mode to optimize FX graph.
    config.debug.run_eagerly = True
    # static kernel switch, suitable for static shapes or scenes with less shape changes.
    config.experimental_config.aclgraph._aclnn_static_shape_kernel = True

    backend = torchair.get_npu_backend(compiler_config=config)
    return backend(graph, example_inputs), None
|
|
|
|
|
|
class AscendCompiler(CompilerInterface):
    """Custom vllm compiler interface for the Ascend platform.

    Dispatches each FX graph either to the experimental torchair-based
    npugraph_ex backend or to the default graph-fusion pass pipeline,
    depending on the ascend config switch ``enable_npugraph_ex``.
    """

    name = "AscendCompiler"

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        """Compile ``graph`` with the backend selected by the ascend config.

        Returns whatever the chosen backend returns: a
        ``(compiled_callable, artifact)`` pair.
        """
        # Re-read the config on every call so the switch is honored per
        # compilation rather than frozen at construction time.
        backend = (npugraph_ex_compile
                   if get_ascend_config().enable_npugraph_ex
                   else fusion_pass_compile)
        return backend(graph, example_inputs, compiler_config,
                       runtime_shape, key)
|