Files
xc-llm-ascend/vllm_ascend/compilation/compiler_interface.py
CodeCat 1402cf6874 [Graph][Fusion] Add QKVNormRope and QKVNormRopeWithBias (#5721)
### What this PR does / why we need it?
This PR builds upon PR
https://github.com/vllm-project/vllm-ascend/pull/5011 and aims to
further enhance the npu_graph_ex_passes module. Based on prior work, we
have added graph optimization support for the add_rms_quant fused
operator in scenarios where a bias term is present—ensuring the fusion
pattern is correctly registered and matched into the computation graph.

For validation, we switched to the Qwen3-235B-A22B-W8A8 model for
QKVNormRopeWithBias and the Qwen3-32B model for QKVNormRope. Benchmark
results show that, compared to the unfused baseline, enabling this
fusion pass significantly improves inference throughput for W8A8
quantized models.
For more details, refer to the
RFC: https://github.com/vllm-project/vllm-ascend/issues/4715
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
```
llm = LLM(
        model=model,
        tensor_parallel_size=GPUs_per_dp_rank,
        enforce_eager=False,
        enable_expert_parallel=enable_expert_parallel,
        trust_remote_code=trust_remote_code,
        gpu_memory_utilization=0.98,
        max_num_batched_tokens=512,
        # load_format="dummy",
        max_model_len=2048,
        max_num_seqs=16,
        quantization="ascend",
        additional_config={
            "refresh": True,
            "enable_npugraph_ex": True
        },
        compilation_config={
            "cudagraph_capture_sizes": [8, 16],
            "cudagraph_mode": "FULL_DECODE_ONLY",
        },
    )
    if profile_dir:
        llm.start_profile()
    outputs = llm.generate(prompts, sampling_params)
    if profile_dir:
        llm.stop_profile()
    for i, output in enumerate(outputs):
        if i >= 5:
            break
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(
            f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
            f"Generated text: {generated_text!r}"
        )
```
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

---------

Signed-off-by: cjian <2318164299@qq.com>
2026-01-22 17:22:41 +08:00

140 lines
5.4 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
from collections.abc import Callable
from typing import Any
import torch
import torch.fx as fx
from torch._dynamo.backends.common import aot_autograd
from torch._inductor.compile_fx import graph_returns_tuple, make_graph_return_tuple
from torch._inductor.decomposition import select_decomp_table
from torch.fx import GraphModule
from vllm.compilation.compiler_interface import CompilerInterface
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm_ascend.ascend_config import NpugraphExConfig, get_ascend_config
from vllm_ascend.utils import COMPILATION_PASS_KEY
def compile_fx(graph: GraphModule, example_inputs: list, inner_compile: Callable, decompositions: dict) -> Callable:
recursive_compile_fx = functools.partial(compile_fx, inner_compile=inner_compile, decompositions=decompositions)
if not graph_returns_tuple(graph):
return make_graph_return_tuple(graph, example_inputs, recursive_compile_fx)
return aot_autograd(fw_compiler=inner_compile)(graph, example_inputs)
def fusion_pass_compile(
    graph: fx.GraphModule,
    example_inputs: list[Any],
    compiler_config: dict[str, Any],
    compile_range: Range,
    key: str | None = None,
) -> tuple[Callable | None, Any | None]:
    """Compile ``graph`` by running the configured pass manager as the inner compiler.

    Args:
        graph: FX graph module to compile.
        example_inputs: Example inputs for the graph.
        compiler_config: Must contain the pass manager under ``COMPILATION_PASS_KEY``.
        compile_range: Not used by this function.
        key: Not used by this function.

    Returns:
        A ``(compiled_fn, None)`` pair; the second element is always ``None``.
    """

    def apply_pass_manager(graph, example_inputs):
        # The pass manager is looked up when the inner compiler runs, and only
        # the graph is handed to it; example_inputs is ignored here.
        pass_manager = compiler_config[COMPILATION_PASS_KEY]
        return pass_manager(graph)

    decomp_table = select_decomp_table()
    return (
        compile_fx(
            graph=graph,
            example_inputs=example_inputs,
            inner_compile=apply_pass_manager,
            decompositions=decomp_table,
        ),
        None,
    )
def npugraph_ex_compile(
    graph: fx.GraphModule,
    example_inputs: list[Any],
    compiler_config: dict[str, Any],
    vllm_config: VllmConfig,
    npugraph_ex_config: NpugraphExConfig,
    compile_range: Range,
    key: str | None = None,
) -> tuple[Callable | None, Any | None]:
    """Compile ``graph`` with the torchair NPU backend.

    Args:
        graph: FX graph module to compile.
        example_inputs: Example inputs for the graph.
        compiler_config: Not used by this function.
        vllm_config: vLLM configuration; read only when static kernels are
            enabled, to derive the decode batch sizes.
        npugraph_ex_config: Controls whether static-shape kernels are enabled.
        compile_range: Not used by this function.
        key: Not used by this function.

    Returns:
        A ``(compiled_fn, None)`` pair; the second element is always ``None``.
    """
    # Imported lazily so the module can be loaded without torchair installed.
    import torchair

    torch.npu.set_compile_mode(jit_compile=False)
    config = torchair.CompilerConfig()
    # use aclgraph mode, avoid the transformation from fx graph to Ascend IR.
    config.mode = "reduce-overhead"
    # execute FX graph in eager mode before graph mode to optimize FX graph.
    config.debug.run_eagerly = True

    if npugraph_ex_config.enable_static_kernel:
        aclgraph_config = config.experimental_config.aclgraph
        aclgraph_config._aclnn_static_shape_kernel = True
        # According to the cudagraph_capture_sizes configuration, set the shapes
        # that can trigger the compilation of static kernel. If this configuration is
        # not applied, new shapes will trigger the compilation of static kernels,
        # affecting program execution.
        speculative_config = vllm_config.speculative_config
        num_spec_tokens = speculative_config.num_speculative_tokens if speculative_config else 0
        uniform_decode_query_len = num_spec_tokens + 1
        max_num_tokens = vllm_config.scheduler_config.max_num_seqs * uniform_decode_query_len
        # Keep only the capture sizes that fall inside the decode token range.
        decode_cudagraph_batch_sizes = [
            x
            for x in vllm_config.compilation_config.cudagraph_capture_sizes
            if max_num_tokens >= x >= uniform_decode_query_len
        ]
        aclgraph_config._aclnn_static_shape_kernel_sym_value_range = decode_cudagraph_batch_sizes

    npugraph_ex = torchair.get_npu_backend(compiler_config=config)

    # torch.compile requires the output of the fx graph to be a tuple
    if not graph_returns_tuple(graph):
        return make_graph_return_tuple(graph, example_inputs, npugraph_ex), None
    return npugraph_ex(graph, example_inputs), None
class AscendCompiler(CompilerInterface):
    """
    Compiler interface implementation for the Ascend platform.

    Depending on the ``npugraph_ex_config.enable`` flag of the Ascend
    configuration, compilation is routed either to ``npugraph_ex_compile``
    or to ``fusion_pass_compile``.
    """

    name = "AscendCompiler"

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return ``vllm_config.compute_hash()``.

        When npugraph_ex is enabled, ``vllm_config`` is also stored on the
        instance because ``compile`` reads it later.
        """
        if get_ascend_config().npugraph_ex_config.enable:
            self.vllm_config = vllm_config
        return vllm_config.compute_hash()

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        compile_range: Range,
        key: str | None = None,
    ) -> tuple[Callable | None, Any | None]:
        """Compile ``graph`` with the backend selected by the Ascend configuration.

        Returns the pair produced by the chosen compile function.
        """
        ex_config = get_ascend_config().npugraph_ex_config

        if not ex_config.enable:
            return fusion_pass_compile(graph, example_inputs, compiler_config, compile_range, key)

        # self.vllm_config is only set by compute_hash, so it must have run first.
        assert hasattr(self, "vllm_config")
        return npugraph_ex_compile(
            graph,
            example_inputs,
            compiler_config,
            self.vllm_config,
            ex_config,
            compile_range,
            key,
        )