[Lint] Style: Convert vllm-ascend/compilation to ruff format (#5912)

### What this PR does / why we need it?
Convert `vllm-ascend/compilation` to ruff format.
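
The reformatting itself is mechanical; it amounts to running ruff's formatter over the directory, roughly as follows (the exact invocation depends on the repo's lint tooling):

```bash
# Reformat the package in place with ruff's formatter
ruff format vllm_ascend/compilation
# Verify formatting in CI without modifying files
ruff format --check vllm_ascend/compilation
```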

### Does this PR introduce _any_ user-facing change?
No user-facing behavior is intended to change; this is a formatting-only migration. However, during the migration we encountered some **errors** in our CI and
testing environments, such as:
```
vllm_ascend/utils.py:653: in <module>
    def register_ascend_customop(vllm_config: VllmConfig | None = None):
                                              ^^^^^^^^^^^^^^^^^
E   TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
```

**1. Root Cause Analysis:**
The project uses a common pattern to break circular dependencies:
```python
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # Placeholder assigned at runtime
```
When Python executes the function definition `def register_ascend_customop(vllm_config: VllmConfig | None)`, it evaluates the annotation expression `VllmConfig | None` at import time.
Since `VllmConfig` has been rebound to `None` at runtime, the expression effectively becomes `None | None`. The PEP 604 `|` operator is implemented for type objects (classes), but `None` is an instance of `NoneType`, not a class, so the operation is unsupported and raises the `TypeError` shown above.
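
A minimal, self-contained reproduction of the failure mode (a sketch, not the actual project file):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.config import VllmConfig  # visible only to type checkers
else:
    VllmConfig = None  # runtime placeholder that breaks the circular import


# Without postponed evaluation, Python evaluates the annotation eagerly at
# definition time. VllmConfig is None here, so `VllmConfig | None` becomes
# `None | None`, and importing this module raises:
#   TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
def register_ascend_customop(vllm_config: VllmConfig | None = None):
    ...
```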

**2. Solution:**
To maintain the modern `|` syntax required by our new linting standards
while preserving our dependency management strategy, I have introduced:
```python
from __future__ import annotations
```
at the top of the affected files. This enables **Postponed Evaluation of
Annotations (PEP 563)**.
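
With the future import in place, the same sketch imports cleanly:

```python
from __future__ import annotations  # must precede all other imports

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # placeholder kept so other modules can still import it


# The annotation is now stored as the string "VllmConfig | None" and is
# never evaluated at import time, so module load succeeds.
def register_ascend_customop(vllm_config: VllmConfig | None = None):
    ...
```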

**3. Impact and Benefits:**
- With the `annotations` future import enabled, Python no longer evaluates the `VllmConfig | None` expression during module load. Instead, it stores the annotation as a string literal, avoiding the `None | None` operation entirely.
- We can keep the `VllmConfig = None` placeholders. This ensures that
other modules can still import these symbols without triggering an
`ImportError`, maintaining a stable dependency graph.
- IDEs and static type checkers (MyPy/Pyright) continue to resolve the
types correctly. This allows us to use modern syntax without sacrificing
type safety or runtime stability.
- The only side effect is that `__annotations__` now contains strings instead of evaluated type objects (see the snippet below). Since this module does not use runtime type enforcement or reflection, this change has zero negative impact on existing functionality.
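
To illustrate the string-annotation behavior (a generic example, independent of the vLLM code):

```python
from __future__ import annotations

import typing


def f(x: int | None = None) -> str:
    return str(x)


# Annotations are stored verbatim as strings...
print(f.__annotations__)         # {'x': 'int | None', 'return': 'str'}

# ...but can still be resolved on demand if ever needed (Python 3.10+):
print(typing.get_type_hints(f))  # {'x': typing.Optional[int], 'return': <class 'str'>}
```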
### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 11b6af5280

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
Authored by: SILONG ZENG
Date: 2026-01-16 20:57:46 +08:00
Committed by: GitHub
Parent: 3af91e5ac4
Commit: 52086394ae
16 changed files with 996 additions and 1140 deletions


```diff
@@ -25,7 +25,6 @@ from vllm.logger import logger
 class AddRMSNormQuantPattern:
     def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
         self.vllm_config = vllm_config
         self.dtype = vllm_config.model_config.dtype
@@ -41,50 +40,48 @@ class AddRMSNormQuantPattern:
         scale = torch.ones(4, device="npu", dtype=self.dtype)
         scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
         offset = torch.zeros(4, device="npu", dtype=self.dtype)
-        return [
-            rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
-            offset
-        ]
+        return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset]

     def register(self, pm_pass: PatternMatcherPass):
-        def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
-                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
-                    scale_reciprocal: torch.Tensor, offset: torch.Tensor):
+        def pattern(
+            rms_norm_input: torch.Tensor,
+            residual: torch.Tensor,
+            rms_norm_weight: torch.Tensor,
+            scale: torch.Tensor,
+            scale_reciprocal: torch.Tensor,
+            offset: torch.Tensor,
+        ):
             """
             Pattern for AddRMSNormQuant fusion.
             """
-            output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
-                                                    rms_norm_weight, self.eps)
+            output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
             out0 = output[0]
             out1 = output[2]
-            quantized_output = torch.ops.vllm.quantize(out0, scale,
-                                                       scale_reciprocal,
-                                                       offset)
+            quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
             return quantized_output, out1

-        def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
-                        rms_norm_weight: torch.Tensor, scale: torch.Tensor,
-                        scale_reciprocal: torch.Tensor, offset: torch.Tensor):
+        def replacement(
+            rms_norm_input: torch.Tensor,
+            residual: torch.Tensor,
+            rms_norm_weight: torch.Tensor,
+            scale: torch.Tensor,
+            scale_reciprocal: torch.Tensor,
+            offset: torch.Tensor,
+        ):
             """
             Replacement for the AddRMSNormQuant fusion.
             """
-            output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
-                                                          residual,
-                                                          rms_norm_weight,
-                                                          scale,
-                                                          offset,
-                                                          epsilon=self.eps)
+            output = torch.ops.npu.npu_add_rms_norm_quant(
+                rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps
+            )
             quantized_output = output[0]
             out1 = output[2]
             return quantized_output, out1

-        pm.register_replacement(pattern, replacement, self.get_inputs(),
-                                pm.fwd_only, pm_pass)
+        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)

 class AddRMSNormQuantPatternWithBias:
     def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
         self.vllm_config = vllm_config
         self.dtype = vllm_config.model_config.dtype
@@ -101,54 +98,51 @@ class AddRMSNormQuantPatternWithBias:
         scale = torch.ones(4, device="npu", dtype=self.dtype)
         scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
         offset = torch.zeros(4, device="npu", dtype=self.dtype)
-        return [
-            rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
-            offset, rmsnorm_bias
-        ]
+        return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset, rmsnorm_bias]

     def register(self, pm_pass: PatternMatcherPass):
-        def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
-                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
-                    scale_reciprocal: torch.Tensor, offset: torch.Tensor,
-                    bias: torch.Tensor):
+        def pattern(
+            rms_norm_input: torch.Tensor,
+            residual: torch.Tensor,
+            rms_norm_weight: torch.Tensor,
+            scale: torch.Tensor,
+            scale_reciprocal: torch.Tensor,
+            offset: torch.Tensor,
+            bias: torch.Tensor,
+        ):
             """
             Pattern for AddRMSNormQuant fusion.
             """
-            output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
-                                                    rms_norm_weight, self.eps)
+            output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
             out0 = output[0]
             out1 = output[2]
             out0 = out0 + bias
-            quantized_output = torch.ops.vllm.quantize(out0, scale,
-                                                       scale_reciprocal,
-                                                       offset)
+            quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
             return quantized_output, out1

-        def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
-                        rms_norm_weight: torch.Tensor, scale: torch.Tensor,
-                        scale_reciprocal: torch.Tensor, offset: torch.Tensor,
-                        bias: torch.Tensor):
+        def replacement(
+            rms_norm_input: torch.Tensor,
+            residual: torch.Tensor,
+            rms_norm_weight: torch.Tensor,
+            scale: torch.Tensor,
+            scale_reciprocal: torch.Tensor,
+            offset: torch.Tensor,
+            bias: torch.Tensor,
+        ):
             """
             Replacement for the AddRMSNormQuant fusion.
             """
-            output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
-                                                          residual,
-                                                          rms_norm_weight,
-                                                          scale,
-                                                          offset,
-                                                          epsilon=self.eps,
-                                                          beta=bias)
+            output = torch.ops.npu.npu_add_rms_norm_quant(
+                rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps, beta=bias
+            )
             quantized_output = output[0]
             out1 = output[2]
             return quantized_output, out1

-        pm.register_replacement(pattern, replacement, self.get_inputs(),
-                                pm.fwd_only, pm_pass)
+        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)

 class AddRMSNormQuantSPPattern:
     def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
         self.vllm_config = vllm_config
         self.dtype = vllm_config.model_config.dtype
@@ -164,53 +158,50 @@ class AddRMSNormQuantSPPattern:
         scale = torch.ones(4, device="npu", dtype=self.dtype)
         scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
         offset = torch.zeros(4, device="npu", dtype=self.dtype)
-        return [
-            rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
-            offset
-        ]
+        return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset]

     def register(self, pm_pass: PatternMatcherPass):
-        def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
-                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
-                    scale_reciprocal: torch.Tensor, offset: torch.Tensor):
+        def pattern(
+            rms_norm_input: torch.Tensor,
+            residual: torch.Tensor,
+            rms_norm_weight: torch.Tensor,
+            scale: torch.Tensor,
+            scale_reciprocal: torch.Tensor,
+            offset: torch.Tensor,
+        ):
             """
             Pattern for AddRMSNormQuant fusion.
             """
-            output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
-                                                    rms_norm_weight, self.eps)
+            output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
             out0 = output[0]
             out1 = output[2]
             out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
-            quantized_output = torch.ops.vllm.quantize(out0, scale,
-                                                       scale_reciprocal,
-                                                       offset)
+            quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
             return quantized_output, out1

-        def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
-                        rms_norm_weight: torch.Tensor, scale: torch.Tensor,
-                        scale_reciprocal: torch.Tensor, offset: torch.Tensor):
+        def replacement(
+            rms_norm_input: torch.Tensor,
+            residual: torch.Tensor,
+            rms_norm_weight: torch.Tensor,
+            scale: torch.Tensor,
+            scale_reciprocal: torch.Tensor,
+            offset: torch.Tensor,
+        ):
             """
             Replacement for the AddRMSNormQuant fusion.
             """
-            output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
-                                                          residual,
-                                                          rms_norm_weight,
-                                                          scale,
-                                                          offset,
-                                                          epsilon=self.eps)
+            output = torch.ops.npu.npu_add_rms_norm_quant(
+                rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps
+            )
             quantized_output = output[0]
             out1 = output[2]
-            quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-                quantized_output, True)
+            quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
             return quantized_output, out1

-        pm.register_replacement(pattern, replacement, self.get_inputs(),
-                                pm.fwd_only, pm_pass)
+        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)

 class AddRMSNormQuantSPPatternWithBias:
     def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
         self.vllm_config = vllm_config
         self.dtype = vllm_config.model_config.dtype
@@ -227,53 +218,50 @@ class AddRMSNormQuantSPPatternWithBias:
         scale = torch.ones(4, device="npu", dtype=self.dtype)
         scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
         offset = torch.zeros(4, device="npu", dtype=self.dtype)
-        return [
-            rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
-            offset, rmsnorm_bias
-        ]
+        return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset, rmsnorm_bias]

     def register(self, pm_pass: PatternMatcherPass):
-        def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
-                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
-                    scale_reciprocal: torch.Tensor, offset: torch.Tensor,
-                    bias: torch.Tensor):
+        def pattern(
+            rms_norm_input: torch.Tensor,
+            residual: torch.Tensor,
+            rms_norm_weight: torch.Tensor,
+            scale: torch.Tensor,
+            scale_reciprocal: torch.Tensor,
+            offset: torch.Tensor,
+            bias: torch.Tensor,
+        ):
             """
             Pattern for AddRMSNormQuant fusion.
             """
-            output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
-                                                    rms_norm_weight, self.eps)
+            output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
             out0 = output[0]
             out1 = output[2]
             out0 = out0 + bias
             out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
-            quantized_output = torch.ops.vllm.quantize(out0, scale,
-                                                       scale_reciprocal,
-                                                       offset)
+            quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
             return quantized_output, out1

-        def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
-                        rms_norm_weight: torch.Tensor, scale: torch.Tensor,
-                        scale_reciprocal: torch.Tensor, offset: torch.Tensor,
-                        bias: torch.Tensor):
+        def replacement(
+            rms_norm_input: torch.Tensor,
+            residual: torch.Tensor,
+            rms_norm_weight: torch.Tensor,
+            scale: torch.Tensor,
+            scale_reciprocal: torch.Tensor,
+            offset: torch.Tensor,
+            bias: torch.Tensor,
+        ):
             """
             Replacement for the AddRMSNormQuant fusion.
             """
-            output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
-                                                          residual,
-                                                          rms_norm_weight,
-                                                          scale,
-                                                          offset,
-                                                          epsilon=self.eps,
-                                                          beta=bias)
+            output = torch.ops.npu.npu_add_rms_norm_quant(
+                rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps, beta=bias
+            )
            quantized_output = output[0]
             out1 = output[2]
-            quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-                quantized_output, True)
+            quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
             return quantized_output, out1

-        pm.register_replacement(pattern, replacement, self.get_inputs(),
-                                pm.fwd_only, pm_pass)
+        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)

 class AddRMSNormQuantFusionPass(VllmInductorPass):
@@ -283,25 +271,19 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
     def __init__(self, vllm_config: VllmConfig):
         super().__init__(vllm_config)
-        self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(
-            pass_name="rmsnorm_quant_fusion_pass")
+        self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="rmsnorm_quant_fusion_pass")
         dtype = vllm_config.model_config.dtype
         if dtype not in (torch.bfloat16, torch.float16):
-            logger.debug("Quant fusion not enabled: unsupported dtype %s",
-                         dtype)
+            logger.debug("Quant fusion not enabled: unsupported dtype %s", dtype)
             return
         common_epsilons = [1e-5, 1e-6]
        for eps in common_epsilons:
-            AddRMSNormQuantPattern(vllm_config,
-                                   eps=eps).register(self.pattern_match_passes)
-            AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(
-                self.pattern_match_passes)
-            AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(
-                self.pattern_match_passes)
-            AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(
-                self.pattern_match_passes)
+            AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
+            AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
+            AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
+            AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)

     def __call__(self, graph: torch.fx.Graph):
         self.begin()
```