Files
xc-llm-ascend/vllm_ascend/compilation/compiler_interface.py
SILONG ZENG 52086394ae [Lint]Style: Convert vllm-ascend/compilation to ruff format (#5912)
### What this PR does / why we need it?
Convert `vllm-ascend/compilation` to ruff format.

### Does this PR introduce _any_ user-facing change?
During this migration, we encountered some **errors** in our CI and
testing environments, such as:
```
vllm_ascend/utils.py:653: in <module>
    def register_ascend_customop(vllm_config: VllmConfig | None = None):
                                              ^^^^^^^^^^^^^^^^^
E   TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
```

**1. Root Cause Analysis:**
The project uses a common pattern to break circular dependencies:
```python
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # Placeholder assigned at runtime
```
When Python executes the function definition `def
register_ascend_customop(vllm_config: VllmConfig | None)`, it eagerly
evaluates the annotation expression `VllmConfig | None`.
Since `VllmConfig` is assigned `None` at runtime, the expression
effectively becomes `None | None`. In Python, `None` is an instance of
`NoneType`. While the `|` operator is implemented for Type objects
(classes), it is not supported for `NoneType` instances, leading to the
`TypeError` shown above.

**2. Solution:**
To maintain the modern `|` syntax required by our new linting standards
while preserving our dependency management strategy, I have introduced:
```python
from __future__ import annotations
```
at the top of the affected files. This enables **Postponed Evaluation of
Annotations (PEP 563)**.

**3. Impact and Benefits:**
- By enabling `annotations`, Python no longer executes the `VllmConfig |
None` operation during module load. Instead, it stores the annotation as
a string literal, completely avoiding the `None | None` calculation.
- We can keep the `VllmConfig = None` placeholders. This ensures that
other modules can still import these symbols without triggering an
`ImportError`, maintaining a stable dependency graph.
- IDEs and static type checkers (MyPy/Pyright) continue to resolve the
types correctly. This allows us to use modern syntax without sacrificing
type safety or runtime stability.
- The only side effect is that `__annotations__` will now return strings
instead of type objects. Since this module does not use runtime type
enforcement or reflection, this change has zero negative impact on
existing functionality.
### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
11b6af5280

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
2026-01-16 20:57:46 +08:00

131 lines
4.8 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
from collections.abc import Callable
from typing import Any
import torch
import torch.fx as fx
from torch._dynamo.backends.common import aot_autograd
from torch._inductor.compile_fx import graph_returns_tuple, make_graph_return_tuple
from torch._inductor.decomposition import select_decomp_table
from torch.fx import GraphModule
from vllm.compilation.compiler_interface import CompilerInterface
from vllm.config.utils import Range
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import COMPILATION_PASS_KEY
def compile_fx(graph: GraphModule, example_inputs: list, inner_compile: Callable, decompositions: dict) -> Callable:
recursive_compile_fx = functools.partial(compile_fx, inner_compile=inner_compile, decompositions=decompositions)
if not graph_returns_tuple(graph):
return make_graph_return_tuple(graph, example_inputs, recursive_compile_fx)
return aot_autograd(fw_compiler=inner_compile)(graph, example_inputs)
def fusion_pass_compile(
    graph: fx.GraphModule,
    example_inputs: list[Any],
    compiler_config: dict[str, Any],
    compile_range: Range,
    key: str | None = None,
) -> tuple[Callable | None, Any | None]:
    """Compile ``graph`` by running the configured pass manager over it.

    The pass manager is read from ``compiler_config[COMPILATION_PASS_KEY]``
    each time the inner compiler is invoked, and its return value is used as
    the compiled forward graph.

    Args:
        graph: The FX graph module to compile.
        example_inputs: Example inputs forwarded to ``compile_fx``.
        compiler_config: Mapping that must contain ``COMPILATION_PASS_KEY``.
        compile_range: Not used by this function.
        key: Not used by this function.

    Returns:
        ``(compiled_fn, None)``; the second element is always ``None``.
    """

    def _run_fusion_passes(graph, example_inputs):
        # Look the pass manager up at call time and apply it to the graph.
        return compiler_config[COMPILATION_PASS_KEY](graph)

    return (
        compile_fx(
            graph=graph,
            example_inputs=example_inputs,
            inner_compile=_run_fusion_passes,
            decompositions=select_decomp_table(),
        ),
        None,
    )
def npugraph_ex_compile(
    graph: fx.GraphModule,
    example_inputs: list[Any],
    compiler_config: dict[str, Any],
    compile_range: Range,
    key: str | None = None,
) -> tuple[Callable | None, Any | None]:
    """Compile ``graph`` with the torchair NPU backend.

    If the graph does not return a tuple, its output is first rewritten to a
    one-element tuple and the module is recompiled. A ``torchair``
    ``CompilerConfig`` is then built and the resulting backend is applied to
    ``graph`` and ``example_inputs``.

    Args:
        graph: The FX graph module to compile; it may be modified in place.
        example_inputs: Example inputs handed to the torchair backend.
        compiler_config: Not used by this function.
        compile_range: Not used by this function.
        key: Not used by this function.

    Returns:
        ``(compiled_graph, None)``; the second element is always ``None``.
    """
    # When currently using the FULL_DECODE_ONLY mode,
    # the piecewise compilation level slicing process
    # in vllm is also encountered.
    # This process causes the output to no longer be
    # wrapped as a tuple when the fx graph has a single
    # output, but torch.compile has a mandatory check.
    inner_graph = graph.graph
    if not graph_returns_tuple(graph):
        out_node = inner_graph.output_node()
        with inner_graph.inserting_before(out_node):
            single_result = out_node.args[0]
            # Insert a `tuple([result])` call right before the output node.
            wrapped = inner_graph.create_node("call_function", tuple, args=([single_result],))
        out_node.args = (wrapped,)
        graph.recompile()

    import torchair

    # TODO: use a better way to lazy register replacement, instead of import one by one
    # As an example, we directly import here to register replacement.
    # import vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant # noqa

    torch.npu.set_compile_mode(jit_compile=False)

    torchair_config = torchair.CompilerConfig()
    # use aclgraph mode, avoid the transformation from fx graph to Ascend IR.
    torchair_config.mode = "reduce-overhead"
    # execute FX graph in eager mode before graph mode to optimize FX graph.
    torchair_config.debug.run_eagerly = True
    # static kernel switch, suitable for static shapes or scenes with less shape changes.
    torchair_config.experimental_config.aclgraph._aclnn_static_shape_kernel = True

    backend = torchair.get_npu_backend(compiler_config=torchair_config)
    return backend(graph, example_inputs), None
class AscendCompiler(CompilerInterface):
    """
    Compiler interface implementation for the Ascend platform.

    ``compile`` dispatches an FX graph module either to the npugraph_ex
    path or to the fusion-pass path, depending on the Ascend configuration.
    """

    name = "AscendCompiler"

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        compile_range: Range,
        key: str | None = None,
    ) -> tuple[Callable | None, Any | None]:
        """Compile ``graph`` with the backend selected by the Ascend config.

        Uses ``npugraph_ex_compile`` when ``enable_npugraph_ex`` is set on the
        config returned by ``get_ascend_config()``; otherwise uses
        ``fusion_pass_compile``. All arguments are forwarded unchanged and the
        selected function's result is returned as-is.
        """
        use_npugraph_ex = get_ascend_config().enable_npugraph_ex
        backend_compile = npugraph_ex_compile if use_npugraph_ex else fusion_pass_compile
        return backend_compile(graph, example_inputs, compiler_config, compile_range, key)