[Lint]Style: Convert vllm-ascend/compilation to ruff format (#5912)

### What this PR does / why we need it?
Convert `vllm-ascend/compilation` to ruff format.

### Does this PR introduce _any_ user-facing change?
During this migration, we encountered some **errors** in our CI and
testing environments, such as:
```
vllm_ascend/utils.py:653: in <module>
    def register_ascend_customop(vllm_config: VllmConfig | None = None):
                                              ^^^^^^^^^^^^^^^^^
E   TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
```

**1. Root Cause Analysis:**
The project uses a common pattern to break circular dependencies:
```python
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # Placeholder assigned at runtime
```
When Python executes the function definition `def
register_ascend_customop(vllm_config: VllmConfig | None)`, it evaluates
the annotation expression `VllmConfig | None` at definition time.
Since `VllmConfig` has been assigned `None` at runtime, the expression
effectively becomes `None | None`. `None` is an instance of `NoneType`,
and while the `|` operator is implemented for type objects (classes, per
PEP 604), it is not defined for `NoneType` instances, which produces the
`TypeError` shown above.
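
The failure is easy to reproduce outside the project. Below is a minimal,
self-contained sketch (the `register_ascend_customop` name mirrors the
traceback above; the guarded vllm import never runs at runtime, so the
snippet executes without vllm installed):
```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.config import VllmConfig  # only evaluated by type checkers
else:
    VllmConfig = None  # runtime placeholder that breaks the circular import

try:
    # Without `from __future__ import annotations`, the annotation below is
    # evaluated eagerly when the `def` statement runs, so Python computes
    # `None | None` -- and NoneType instances do not implement `|`.
    def register_ascend_customop(vllm_config: VllmConfig | None = None): ...
except TypeError as exc:
    print(exc)  # unsupported operand type(s) for |: 'NoneType' and 'NoneType'
```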

**2. Solution:**
To maintain the modern `|` syntax required by our new linting standards
while preserving our dependency management strategy, I have introduced:
```python
from __future__ import annotations
```
at the top of the affected files. This enables **Postponed Evaluation of
Annotations (PEP 563)**.
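
For reference, a minimal sketch of the fixed module layout (same
hypothetical standalone snippet as above, not the literal vllm-ascend file):
```python
from __future__ import annotations  # must precede all other statements

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # placeholder stays importable for other modules


# The annotation is stored as the string "VllmConfig | None" and is never
# evaluated at definition time, so importing this module now succeeds.
def register_ascend_customop(vllm_config: VllmConfig | None = None): ...
```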

**3. Impact and Benefits:**
- With postponed evaluation enabled, Python no longer evaluates
`VllmConfig | None` at module load time. Instead, it stores the annotation
as a string literal, so the `None | None` expression is never computed.
- We can keep the `VllmConfig = None` placeholders. This ensures that
other modules can still import these symbols without triggering an
`ImportError`, maintaining a stable dependency graph.
- IDEs and static type checkers (MyPy/Pyright) continue to resolve the
types correctly. This allows us to use modern syntax without sacrificing
type safety or runtime stability.
- The only side effect is that `__annotations__` now contains strings
instead of type objects. Since this module does not rely on runtime type
enforcement or reflection, this change has no negative impact on existing
functionality (a short demonstration follows below).
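
A quick way to see this behaviour (standalone sketch; note that the
annotation stays harmless even though `VllmConfig` is undefined here):
```python
from __future__ import annotations


def register_ascend_customop(vllm_config: VllmConfig | None = None): ...


# The annotation is kept as an unevaluated string, so this runs even though
# VllmConfig is not defined anywhere in this snippet.
print(register_ascend_customop.__annotations__)
# {'vllm_config': 'VllmConfig | None'}
```
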
### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 11b6af5280

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
Authored by SILONG ZENG on 2026-01-16 20:57:46 +08:00; committed by GitHub
parent 3af91e5ac4
commit 52086394ae
16 changed files with 996 additions and 1140 deletions


```diff
@@ -48,7 +48,8 @@ def _extra_stream_scope_check(match: Match) -> bool:
         logger.debug(
             f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
             f"Multiple streams found: {non_default_streams}. "
-            f"Fusion is not supported for cross-stream operations.")
+            f"Fusion is not supported for cross-stream operations."
+        )
         return False
     return True
@@ -57,24 +58,29 @@ def _extra_stream_scope_check(match: Match) -> bool:
 @functools.lru_cache(None)
 # The replacement registered here will be actually executed after AOT.
 def replacement_add_rms_norm_quant(epsilon):
-    def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
-                rms_norm_weight: torch.Tensor, scale: torch.Tensor,
-                offset: torch.Tensor):
+    def pattern(
+        rms_norm_input: torch.Tensor,
+        residual: torch.Tensor,
+        rms_norm_weight: torch.Tensor,
+        scale: torch.Tensor,
+        offset: torch.Tensor,
+    ):
         """
         Pattern for AddRMSNormQuant fusion.
         """
-        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
-                                                rms_norm_weight, epsilon)
+        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
         out0 = output[0]
         out1 = output[2]
-        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
-                                                      torch.qint8, -1, False)
+        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
         return quantized_output, out1
 
-    def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
-                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
-                    offset: torch.Tensor):
+    def replacement(
+        rms_norm_input: torch.Tensor,
+        residual: torch.Tensor,
+        rms_norm_weight: torch.Tensor,
+        scale: torch.Tensor,
+        offset: torch.Tensor,
+    ):
         """
         Replacement for the AddRMSNormQuant fusion.
         """
@@ -82,10 +88,12 @@ def replacement_add_rms_norm_quant(epsilon):
             rms_norm_input,
             residual,
             rms_norm_weight,
-            # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
-            1. / scale,
+            # The inverse of scale is required by npu_add_rms_norm_quant kernel
+            # which is opposite to the npu_quantize kernel.
+            1.0 / scale,
             offset,
-            epsilon=epsilon)
+            epsilon=epsilon,
+        )
         quantized_output = output[0]
         out1 = output[2]
         return quantized_output, out1
@@ -103,33 +111,39 @@ def replacement_add_rms_norm_quant(epsilon):
     import torchair
-    torchair.register_replacement(search_fn=pattern,
-                                  replace_fn=replacement,
-                                  example_inputs=get_inputs(),
-                                  extra_check=_extra_stream_scope_check)
+    torchair.register_replacement(
+        search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
+    )
 
 
 # The replacement registered here will be actually executed after AOT.
 def replacement_add_rms_norm_quant_with_bias(epsilon):
-    def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
-                rms_norm_weight: torch.Tensor, scale: torch.Tensor,
-                offset: torch.Tensor, bias: torch.Tensor):
+    def pattern(
+        rms_norm_input: torch.Tensor,
+        residual: torch.Tensor,
+        rms_norm_weight: torch.Tensor,
+        scale: torch.Tensor,
+        offset: torch.Tensor,
+        bias: torch.Tensor,
+    ):
         """
         Pattern for AddRMSNormQuantWithBias fusion.
         """
-        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
-                                                rms_norm_weight, epsilon)
+        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
         out0 = output[0]
         out1 = output[2]
         out0 = out0 + bias
-        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
-                                                      torch.qint8, -1, False)
+        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
         return quantized_output, out1
 
-    def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
-                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
-                    offset: torch.Tensor, bias: torch.Tensor):
+    def replacement(
+        rms_norm_input: torch.Tensor,
+        residual: torch.Tensor,
+        rms_norm_weight: torch.Tensor,
+        scale: torch.Tensor,
+        offset: torch.Tensor,
+        bias: torch.Tensor,
+    ):
         """
         Replacement for AddRMSNormQuantWithBias fusion.
         """
@@ -137,11 +151,13 @@ def replacement_add_rms_norm_quant_with_bias(epsilon):
             rms_norm_input,
             residual,
             rms_norm_weight,
-            # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
-            1. / scale,
+            # The inverse of scale is required by npu_add_rms_norm_quant kernel
+            # which is opposite to the npu_quantize kernel.
+            1.0 / scale,
             offset,
             epsilon=epsilon,
-            beta=bias)
+            beta=bias,
+        )
         quantized_output = output[0]
         out1 = output[2]
         return quantized_output, out1
@@ -156,40 +172,41 @@ def replacement_add_rms_norm_quant_with_bias(epsilon):
         rmsnorm_bias = torch.randn(4, device="npu")
         scale = torch.ones(4, device="npu")
         offset = torch.zeros(4, device="npu")
-        return [
-            rms_norm_input, residual, rms_norm_weight, scale, offset,
-            rmsnorm_bias
-        ]
+        return [rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias]
 
     import torchair
-    torchair.register_replacement(search_fn=pattern,
-                                  replace_fn=replacement,
-                                  example_inputs=get_inputs(),
-                                  extra_check=_extra_stream_scope_check)
+    torchair.register_replacement(
+        search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
+    )
 
 
 # The replacement registered here will be actually executed after AOT.
 def replacement_add_rms_norm_quant_sp_pattern(epsilon):
-    def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
-                rms_norm_weight: torch.Tensor, scale: torch.Tensor,
-                offset: torch.Tensor):
+    def pattern(
+        rms_norm_input: torch.Tensor,
+        residual: torch.Tensor,
+        rms_norm_weight: torch.Tensor,
+        scale: torch.Tensor,
+        offset: torch.Tensor,
+    ):
         """
         Pattern for AddRMSNormQuantSPPattern fusion.
         """
-        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
-                                                rms_norm_weight, epsilon)
+        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
         out0 = output[0]
         out1 = output[2]
         out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
-        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
-                                                      torch.qint8, -1, False)
+        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
         return quantized_output, out1
 
-    def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
-                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
-                    offset: torch.Tensor):
+    def replacement(
+        rms_norm_input: torch.Tensor,
+        residual: torch.Tensor,
+        rms_norm_weight: torch.Tensor,
+        scale: torch.Tensor,
+        offset: torch.Tensor,
+    ):
         """
         Replacement for the AddRMSNormQuantSPPattern fusion.
         """
@@ -197,14 +214,15 @@ def replacement_add_rms_norm_quant_sp_pattern(epsilon):
             rms_norm_input,
             residual,
             rms_norm_weight,
-            # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
-            1. / scale,
+            # The inverse of scale is required by npu_add_rms_norm_quant kernel
+            # which is opposite to the npu_quantize kernel.
+            1.0 / scale,
             offset,
-            epsilon=epsilon)
+            epsilon=epsilon,
+        )
         quantized_output = output[0]
         out1 = output[2]
-        quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-            quantized_output, True)
+        quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
         return quantized_output, out1
 
     def get_inputs():
@@ -220,34 +238,40 @@ def replacement_add_rms_norm_quant_sp_pattern(epsilon):
     import torchair
-    torchair.register_replacement(search_fn=pattern,
-                                  replace_fn=replacement,
-                                  example_inputs=get_inputs(),
-                                  extra_check=_extra_stream_scope_check)
+    torchair.register_replacement(
+        search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
+    )
 
 
 # The replacement registered here will be actually executed after AOT.
 def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
-    def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
-                rms_norm_weight: torch.Tensor, scale: torch.Tensor,
-                offset: torch.Tensor, bias: torch.Tensor):
+    def pattern(
+        rms_norm_input: torch.Tensor,
+        residual: torch.Tensor,
+        rms_norm_weight: torch.Tensor,
+        scale: torch.Tensor,
+        offset: torch.Tensor,
+        bias: torch.Tensor,
+    ):
         """
         Pattern for AddRMSNormQuantSPPatternWithBias fusion.
         """
-        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
-                                                rms_norm_weight, epsilon)
+        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
         out0 = output[0]
         out1 = output[2]
         out0 = out0 + bias
         out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
-        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
-                                                      torch.qint8, -1, False)
+        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
         return quantized_output, out1
 
-    def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
-                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
-                    offset: torch.Tensor, bias: torch.Tensor):
+    def replacement(
+        rms_norm_input: torch.Tensor,
+        residual: torch.Tensor,
+        rms_norm_weight: torch.Tensor,
+        scale: torch.Tensor,
+        offset: torch.Tensor,
+        bias: torch.Tensor,
+    ):
         """
         Replacement for the AddRMSNormQuantSPPatternWithBias fusion.
         """
@@ -255,15 +279,16 @@ def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
             rms_norm_input,
             residual,
             rms_norm_weight,
-            # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
-            1. / scale,
+            # The inverse of scale is required by npu_add_rms_norm_quant kernel
+            # which is opposite to the npu_quantize kernel.
+            1.0 / scale,
             offset,
             epsilon=epsilon,
-            beta=bias)
+            beta=bias,
+        )
         quantized_output = output[0]
         out1 = output[2]
-        quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-            quantized_output, True)
+        quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
         return quantized_output, out1
 
     def get_inputs():
@@ -276,25 +301,19 @@ def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
         rmsnorm_bias = torch.randn(4, device="npu")
         scale = torch.ones(4, device="npu")
         offset = torch.zeros(4, device="npu")
-        return [
-            rms_norm_input, residual, rms_norm_weight, scale, offset,
-            rmsnorm_bias
-        ]
+        return [rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias]
 
     import torchair
-    torchair.register_replacement(search_fn=pattern,
-                                  replace_fn=replacement,
-                                  example_inputs=get_inputs(),
-                                  extra_check=_extra_stream_scope_check)
+    torchair.register_replacement(
+        search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
+    )
 
 
 # register converter for pass
 common_epsilons = [1e-5, 1e-6]
 for eps in common_epsilons:
-    logger.info(
-        f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}"
-    )
+    logger.info(f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}")
     replacement_add_rms_norm_quant(eps)
     replacement_add_rms_norm_quant_with_bias(eps)
     replacement_add_rms_norm_quant_sp_pattern(eps)
```