[Lint]Style: Convert vllm-ascend/compilation to ruff format (#5912)

### What this PR does / why we need it?
Convert `vllm-ascend/compilation` to ruff format.

### Does this PR introduce _any_ user-facing change?
No. However, during this migration we encountered some **errors** in our CI and
testing environments, such as:
```
vllm_ascend/utils.py:653: in <module>
    def register_ascend_customop(vllm_config: VllmConfig | None = None):
                                              ^^^^^^^^^^^^^^^^^
E   TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
```

**1. Root Cause Analysis:**
The project uses a common pattern to break circular dependencies:
```python
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # Placeholder assigned at runtime
```
When Python executes the function definition `def
register_ascend_customop(vllm_config: VllmConfig | None)`, it eagerly
evaluates the annotation expression `VllmConfig | None`.
Since `VllmConfig` is assigned `None` at runtime, the expression
effectively becomes `None | None`. The `|` union operator is implemented
for class objects (PEP 604 added `type.__or__`), but `None` is an
instance of `NoneType`, not a class, so the expression raises the
`TypeError` shown above.
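
To make the failure mode concrete, here is a minimal, standalone reproduction (a hypothetical file; running it does not require vLLM, because `TYPE_CHECKING` is `False` at runtime):
```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.config import VllmConfig  # seen only by type checkers
else:
    VllmConfig = None  # runtime placeholder that breaks the circular import

# `int | None` would work, since `int` is a class and PEP 604 gives classes
# a `type.__or__`. Here the annotation is evaluated eagerly when the `def`
# statement runs, so it becomes `None | None` and raises:
#   TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
def register_ascend_customop(vllm_config: VllmConfig | None = None):
    ...
```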

**2. Solution:**
To maintain the modern `|` syntax required by our new linting standards
while preserving our dependency management strategy, I have introduced:
```python
from __future__ import annotations
```
at the top of the affected files. This enables **Postponed Evaluation of
Annotations (PEP 563)**.
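
Applied to the sketch above, the future import lets the module import cleanly, because every annotation in the file is stored as a string rather than evaluated:
```python
from __future__ import annotations  # must be the first statement in the file

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # placeholder stays importable by other modules

# The annotation is kept as the string "VllmConfig | None" and is never
# evaluated at import time, so no TypeError is raised.
def register_ascend_customop(vllm_config: VllmConfig | None = None):
    ...
```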

**3. Impact and Benefits:**
- With `from __future__ import annotations` enabled, Python no longer
evaluates the `VllmConfig | None` expression during module load.
Instead, it stores the annotation as a string literal, avoiding the
`None | None` evaluation entirely.
- We can keep the `VllmConfig = None` placeholders. This ensures that
other modules can still import these symbols without triggering an
`ImportError`, maintaining a stable dependency graph.
- IDEs and static type checkers (MyPy/Pyright) continue to resolve the
types correctly. This allows us to use modern syntax without sacrificing
type safety or runtime stability.
- The only side effect is that `__annotations__` now returns strings
instead of type objects, as illustrated in the snippet below. Since this
module does not rely on runtime type enforcement or reflection, this
change has no negative impact on existing functionality.
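
For illustration, inspecting the sketch above in a REPL shows that the stored annotations are now plain strings:
```python
>>> register_ascend_customop.__annotations__
{'vllm_config': 'VllmConfig | None'}
>>> type(register_ascend_customop.__annotations__["vllm_config"])
<class 'str'>
```
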
### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 11b6af5280

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
Author: SILONG ZENG
Date: 2026-01-16 20:57:46 +08:00
Committed by: GitHub
Parent: 3af91e5ac4
Commit: 52086394ae
16 changed files with 996 additions and 1140 deletions


@@ -2,9 +2,10 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import dataclasses
+from collections.abc import Callable
 from contextlib import ExitStack
 from dataclasses import dataclass
-from typing import Any, Callable, Optional
+from typing import Any
 from unittest.mock import patch

 import numpy as np
@@ -27,12 +28,12 @@ from ..utils import weak_ref_tensors
 @dataclasses.dataclass
 class ACLGraphEntry:
     batch_descriptor: BatchDescriptor
-    aclgraph: Optional[torch.npu.NPUGraph] = None
-    output: Optional[Any] = None
+    aclgraph: torch.npu.NPUGraph | None = None
+    output: Any | None = None

     # for aclgraph debugging, track the input addresses
     # during capture, and check if they are the same during replay
-    input_addresses: Optional[list[int]] = None
+    input_addresses: list[int] | None = None


 class ACLGraphWrapper:
@@ -60,11 +61,13 @@ class ACLGraphWrapper:
     guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
     """

-    def __init__(self,
-                 runnable: Callable,
-                 vllm_config: VllmConfig,
-                 runtime_mode: CUDAGraphMode,
-                 cudagraph_options: Optional[CUDAGraphOptions] = None):
+    def __init__(
+        self,
+        runnable: Callable,
+        vllm_config: VllmConfig,
+        runtime_mode: CUDAGraphMode,
+        cudagraph_options: CUDAGraphOptions | None = None,
+    ):
         self.runnable = runnable
         self.vllm_config = vllm_config
         self.runtime_mode = runtime_mode
@@ -83,15 +86,13 @@ class ACLGraphWrapper:
         self.aclgraph_options = cudagraph_options
         # the entries for different batch descriptors that we need to capture
         # aclgraphs for.
-        self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry]\
-            = {}
+        self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry] = {}

     def __getattr__(self, key: str):
         # allow accessing the attributes of the runnable.
         if hasattr(self.runnable, key):
             return getattr(self.runnable, key)
-        raise AttributeError(f"Attribute {key} not exists in the runnable of "
-                             f"aclgraph wrapper: {self.runnable}")
+        raise AttributeError(f"Attribute {key} not exists in the runnable of aclgraph wrapper: {self.runnable}")

     def unwrap(self) -> Callable:
         # in case we need to access the original runnable.
@@ -102,8 +103,7 @@ class ACLGraphWrapper:
         batch_descriptor = forward_context.batch_descriptor
         aclgraph_runtime_mode = forward_context.cudagraph_runtime_mode

-        if aclgraph_runtime_mode == CUDAGraphMode.NONE or \
-                aclgraph_runtime_mode != self.runtime_mode:
+        if aclgraph_runtime_mode == CUDAGraphMode.NONE or aclgraph_runtime_mode != self.runtime_mode:
             # CUDAGraphMode.NONE could mean the profile run, a warmup run, or
             # running without aclgraphs.
             # We do not trigger capture/replay if the runtime mode is not
@@ -114,8 +114,7 @@ class ACLGraphWrapper:
         if batch_descriptor not in self.concrete_aclgraph_entries:
             # create a new entry for this batch descriptor
-            self.concrete_aclgraph_entries[batch_descriptor] = \
-                ACLGraphEntry(batch_descriptor=batch_descriptor)
+            self.concrete_aclgraph_entries[batch_descriptor] = ACLGraphEntry(batch_descriptor=batch_descriptor)

         entry = self.concrete_aclgraph_entries[batch_descriptor]
@@ -125,14 +124,11 @@ class ACLGraphWrapper:
                 # capturing is fast, we don't need to log it for every
                 # shape. E.g. we only log it for the first subgraph in
                 # piecewise mode.
-                logger.debug("Capturing a aclgraph on (%s,%s)",
-                             self.runtime_mode.name, entry.batch_descriptor)
+                logger.debug("Capturing a aclgraph on (%s,%s)", self.runtime_mode.name, entry.batch_descriptor)
             # validate that aclgraph capturing is legal at this point.
             validate_cudagraph_capturing_enabled()

-            input_addresses = [
-                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
-            ]
+            input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]
             entry.input_addresses = input_addresses
             aclgraph = torch.npu.NPUGraph()
@@ -145,8 +141,7 @@ class ACLGraphWrapper:
                 # therefore, we only run gc for the first graph,
                 # and disable gc for the rest of the graphs.
                 stack.enter_context(patch("gc.collect", lambda: None))
-                stack.enter_context(
-                    patch("torch.npu.empty_cache", lambda: None))
+                stack.enter_context(patch("torch.npu.empty_cache", lambda: None))

             # mind-exploding: carefully manage the reference and memory.
             forward_context.capturing = True
@@ -183,13 +178,12 @@ class ACLGraphWrapper:
         if self.is_debugging_mode:
             # check if the input addresses are the same
-            new_input_addresses = [
-                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
-            ]
+            new_input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]
             assert new_input_addresses == entry.input_addresses, (
                 f"Input addresses for aclgraphs are different "
                 f"during replay. Expected {entry.input_addresses}, "
-                f"got {new_input_addresses}")
+                f"got {new_input_addresses}"
+            )

         logger.info_once("Replaying aclgraph")
         # In async scheduling or multi-threaded (MT) scenarios when graph mode is FULL, it is possible that
@@ -209,8 +203,7 @@ def weak_ref_workspaces(params):
     for num_tokens in params.workspaces:
         if params.workspaces[num_tokens] is None:
             continue
-        params.workspaces[num_tokens] = weak_ref_tensors(
-            params.workspaces[num_tokens])
+        params.workspaces[num_tokens] = weak_ref_tensors(params.workspaces[num_tokens])


 def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
@@ -219,10 +212,10 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
     # for each layer's attention op in the graph.
     with torch.npu.stream(update_stream):
         for key, param, handle, event in zip(
-                forward_context.attn_metadata,
-                graph_params.attn_params[runtime_shape],
-                graph_params.handles[runtime_shape],
-                graph_params.events[runtime_shape],
+            forward_context.attn_metadata,
+            graph_params.attn_params[runtime_shape],
+            graph_params.handles[runtime_shape],
+            graph_params.events[runtime_shape],
         ):
             (
                 query,
@@ -254,18 +247,21 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
                 scale_value=scale,
                 block_table=block_table,
                 context_lens=seq_lens,
-                out=output)
+                out=output,
+            )
             torch.npu.graph_task_update_begin(update_stream, handle)
-            torch_npu._npu_paged_attention(query=query,
-                                           key_cache=key_cache,
-                                           value_cache=value_cache,
-                                           num_kv_heads=num_kv_heads,
-                                           num_heads=num_heads,
-                                           scale_value=scale,
-                                           block_table=block_table,
-                                           context_lens=seq_lens,
-                                           out=output,
-                                           workspace=workspace)
+            torch_npu._npu_paged_attention(
+                query=query,
+                key_cache=key_cache,
+                value_cache=value_cache,
+                num_kv_heads=num_kv_heads,
+                num_heads=num_heads,
+                scale_value=scale,
+                block_table=block_table,
+                context_lens=seq_lens,
+                out=output,
+                workspace=workspace,
+            )
             torch.npu.graph_task_update_end(update_stream)
             event.record(update_stream)
@@ -282,18 +278,29 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
     # filters out the update operations for linear_attn.
     with torch.npu.stream(update_stream):
         for key, param, handle, event in zip(
-                forward_context.attn_metadata,
-                graph_params.attn_params[runtime_shape],
-                graph_params.handles[runtime_shape],
-                graph_params.events[runtime_shape],
+            forward_context.attn_metadata,
+            graph_params.attn_params[runtime_shape],
+            graph_params.handles[runtime_shape],
+            graph_params.events[runtime_shape],
         ):
-            (query, key_cache, value, block_tables, attn_mask, block_size,
-             seq_lens, query_start_loc, num_kv_heads, num_heads, scale,
-             attn_output, softmax_lse) = param
+            (
+                query,
+                key_cache,
+                value,
+                block_tables,
+                attn_mask,
+                block_size,
+                seq_lens,
+                query_start_loc,
+                num_kv_heads,
+                num_heads,
+                scale,
+                attn_output,
+                softmax_lse,
+            ) = param
             seq_lens = forward_context.attn_metadata[key].seq_lens_list
-            actual_seq_lengths_q = forward_context.attn_metadata[
-                key].actual_seq_lengths_q
+            actual_seq_lengths_q = forward_context.attn_metadata[key].actual_seq_lengths_q
             torch.npu.graph_task_update_begin(update_stream, handle)
             torch_npu.npu_fused_infer_attention_score.out(
                 query=query,
@@ -317,16 +324,14 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
             event.record(update_stream)


-def update_attn_params(update_stream, forward_context, runtime_shape,
-                       vllm_config):
+def update_attn_params(update_stream, forward_context, runtime_shape, vllm_config):
     if using_paged_attention(runtime_shape, vllm_config):
         _update_attn_pa_params(update_stream, forward_context, runtime_shape)
     else:
         _update_attn_fia_params(update_stream, forward_context, runtime_shape)


-def update_mla_attn_params(update_stream, forward_context, runtime_shape,
-                           speculative_config):
+def update_mla_attn_params(update_stream, forward_context, runtime_shape, speculative_config):
     if forward_context.is_draft_model:
         graph_params = get_draft_graph_params()
     else:
@@ -335,41 +340,44 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
     # for each layer's attention op in the graph.
     with torch.npu.stream(update_stream):
         for key, param, handle, event in zip(
-                forward_context.attn_metadata,
-                graph_params.attn_params[runtime_shape],
-                graph_params.handles[runtime_shape],
-                graph_params.events[runtime_shape],
+            forward_context.attn_metadata,
+            graph_params.attn_params[runtime_shape],
+            graph_params.handles[runtime_shape],
+            graph_params.events[runtime_shape],
         ):
-            (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
-             attn_mask, sparse_mode, scale, block_table, block_size,
-             seq_lens_list, actual_seq_lengths, attn_output,
-             softmax_lse) = param
-            seq_lens_list = forward_context.attn_metadata[
-                key].decode.seq_lens_list
-            if speculative_config and speculative_config.method == "mtp" \
-                    and not forward_context.is_draft_model:
-                actual_seq_lengths = forward_context.attn_metadata[
-                    key].decode.actual_seq_lengths_q
+            (
+                q_nope,
+                k_nope,
+                q_pe,
+                k_pe,
+                num_heads,
+                num_kv_heads,
+                input_layout,
+                attn_mask,
+                sparse_mode,
+                scale,
+                block_table,
+                block_size,
+                seq_lens_list,
+                actual_seq_lengths,
+                attn_output,
+                softmax_lse,
+            ) = param
+            seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
+            if speculative_config and speculative_config.method == "mtp" and not forward_context.is_draft_model:
+                actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
                 spec_multiple = speculative_config.num_speculative_tokens + 1
-                seq_lens_list = seq_lens_list + [0] * (
-                    runtime_shape // spec_multiple - len(seq_lens_list))
-                actual_seq_lengths = [
-                    spec_multiple * (i + 1)
-                    for i in range(runtime_shape // spec_multiple)
-                ]
+                seq_lens_list = seq_lens_list + [0] * (runtime_shape // spec_multiple - len(seq_lens_list))
+                actual_seq_lengths = [spec_multiple * (i + 1) for i in range(runtime_shape // spec_multiple)]
             elif forward_context.is_draft_model:
-                actual_seq_lengths = forward_context.attn_metadata[
-                    key].decode.actual_seq_lengths_q
-                block_table = forward_context.attn_metadata[
-                    key].decode.block_table
+                actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
+                block_table = forward_context.attn_metadata[key].decode.block_table
                 # TODO: This is a hack and should be fixed in the future.
                 if speculative_config.disable_padded_drafter_batch:
-                    block_table = block_table[:len(actual_seq_lengths)]
-                    seq_lens_list = seq_lens_list + [0] * (
-                        len(actual_seq_lengths) - len(seq_lens_list))
+                    block_table = block_table[: len(actual_seq_lengths)]
+                    seq_lens_list = seq_lens_list + [0] * (len(actual_seq_lengths) - len(seq_lens_list))
             else:
-                seq_lens_list = seq_lens_list + [0] * (runtime_shape -
-                                                       len(seq_lens_list))
+                seq_lens_list = seq_lens_list + [0] * (runtime_shape - len(seq_lens_list))

             torch.npu.graph_task_update_begin(update_stream, handle)
             torch_npu.npu_fused_infer_attention_score.out(
@@ -391,7 +399,8 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
                 actual_seq_lengths_kv=seq_lens_list,
                 actual_seq_lengths=actual_seq_lengths,
                 workspace=graph_params.workspaces.get(runtime_shape),
-                out=[attn_output, softmax_lse])
+                out=[attn_output, softmax_lse],
+            )
             torch.npu.graph_task_update_end(update_stream)
             event.record(update_stream)
@@ -403,34 +412,40 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()

     with torch.npu.stream(update_stream):
         for key, param, handle, event in zip(
-                forward_context.attn_metadata,
-                graph_params.attn_params[runtime_shape],
-                graph_params.handles[runtime_shape],
-                graph_params.events[runtime_shape],
+            forward_context.attn_metadata,
+            graph_params.attn_params[runtime_shape],
+            graph_params.handles[runtime_shape],
+            graph_params.events[runtime_shape],
         ):
-            (q_nope, k_nope, value, num_heads, num_kv_heads, scale,
-             block_table, block_size, actual_seq_lengths_kv,
-             actual_seq_lengths_q, attn_output, softmax_lse, dcp_size,
-             pcp_rank, dcp_rank) = param
+            (
+                q_nope,
+                k_nope,
+                value,
+                num_heads,
+                num_kv_heads,
+                scale,
+                block_table,
+                block_size,
+                actual_seq_lengths_kv,
+                actual_seq_lengths_q,
+                attn_output,
+                softmax_lse,
+                dcp_size,
+                pcp_rank,
+                dcp_rank,
+            ) = param
             attn_metadata = forward_context.attn_metadata[key]
-            actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:,
-                                                                                             pcp_rank,
-                                                                                             dcp_rank]
+            actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank, dcp_rank]
             pad_length = runtime_shape - len(actual_seq_lengths_kv)
             if pad_length > 0:
-                pad_tensor = np.zeros(pad_length,
-                                      dtype=actual_seq_lengths_kv.dtype)
-                actual_seq_lengths_kv = np.concatenate(
-                    [actual_seq_lengths_kv, pad_tensor])
+                pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
+                actual_seq_lengths_kv = np.concatenate([actual_seq_lengths_kv, pad_tensor])

-            actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[:
-                                                                      attn_metadata
-                                                                      .
-                                                                      num_decode_tokens]
-            if (runtime_shape - len(actual_seq_lengths_q)):
-                actual_seq_lengths_q = actual_seq_lengths_q + [
-                    actual_seq_lengths_q[-1]
-                ] * (runtime_shape - len(actual_seq_lengths_q))
+            actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[: attn_metadata.num_decode_tokens]
+            if runtime_shape - len(actual_seq_lengths_q):
+                actual_seq_lengths_q = actual_seq_lengths_q + [actual_seq_lengths_q[-1]] * (
+                    runtime_shape - len(actual_seq_lengths_q)
+                )

             if dcp_size > 1:
                 num_heads = num_heads * dcp_size
@@ -453,14 +468,14 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
                 actual_seq_lengths_kv=actual_seq_lengths_kv,
                 actual_seq_lengths=actual_seq_lengths_q,
                 workspace=graph_params.workspaces.get(runtime_shape),
-                out=[attn_output, softmax_lse])
+                out=[attn_output, softmax_lse],
+            )
             torch.npu.graph_task_update_end(update_stream)
             event.record(update_stream)


-def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
-                                   runtime_shape):
+def update_mla_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
     if forward_context.is_draft_model:
         graph_params = get_draft_graph_params()
     else:
@@ -469,13 +484,24 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
     # for each layer's attention op in the graph.
     with torch.npu.stream(update_stream):
         for key, param, handle, event in zip(
-                forward_context.attn_metadata,
-                graph_params.attn_params[runtime_shape],
-                graph_params.handles[runtime_shape],
-                graph_params.events[runtime_shape],
+            forward_context.attn_metadata,
+            graph_params.attn_params[runtime_shape],
+            graph_params.handles[runtime_shape],
+            graph_params.events[runtime_shape],
         ):
-            (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads,
-             scale, num_kv_heads, attn_output, softmax_lse) = param
+            (
+                q_nope,
+                q_pe,
+                k_nope,
+                k_pe,
+                block_table,
+                seq_len,
+                num_heads,
+                scale,
+                num_kv_heads,
+                attn_output,
+                softmax_lse,
+            ) = param

             decode_meta = forward_context.attn_metadata[key].decode
             seq_len = decode_meta.cp_seq_len
@@ -484,9 +510,7 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
             # to avoid irregular attn_mask shape,
             # so there's no need to divide runtime_shape by spec_multiple
             pad_length = runtime_shape - len(seq_len)
-            pad_tensor = torch.zeros(pad_length,
-                                     dtype=seq_len.dtype,
-                                     device=seq_len.device)
+            pad_tensor = torch.zeros(pad_length, dtype=seq_len.dtype, device=seq_len.device)
             seq_len = torch.cat([seq_len, pad_tensor], dim=0)

             torch.npu.graph_task_update_begin(update_stream, handle)
@@ -505,7 +529,8 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
                 calc_type="calc_type_ring",
                 workspace=graph_params.workspaces.get(runtime_shape),
                 output=attn_output,
-                lse=softmax_lse)
+                lse=softmax_lse,
+            )
             torch.npu.graph_task_update_end(update_stream)
             event.record(update_stream)
@@ -519,7 +544,7 @@ class GraphParams:
     attn_params: dict[int, list[tuple]]


-_graph_params: Optional[GraphParams] = None
+_graph_params: GraphParams | None = None


 def set_graph_params(aclgraph_capture_sizes: list[int]):
@@ -527,14 +552,10 @@ def set_graph_params(aclgraph_capture_sizes: list[int]):
     if _graph_params is not None:
         raise ValueError("Graph parameters have already been set!")
     _graph_params = GraphParams(
-        {size: []
-         for size in aclgraph_capture_sizes},
-        {size: None
-         for size in aclgraph_capture_sizes},
-        {size: []
-         for size in aclgraph_capture_sizes},
-        {size: []
-         for size in aclgraph_capture_sizes},
+        {size: [] for size in aclgraph_capture_sizes},
+        {size: None for size in aclgraph_capture_sizes},
+        {size: [] for size in aclgraph_capture_sizes},
+        {size: [] for size in aclgraph_capture_sizes},
     )
@@ -548,7 +569,7 @@ def get_graph_params():
     return _graph_params


-_draft_graph_params: Optional[GraphParams] = None
+_draft_graph_params: GraphParams | None = None


 def set_draft_graph_params(aclgraph_capture_sizes: list[int]):
@@ -556,14 +577,10 @@ def set_draft_graph_params(aclgraph_capture_sizes: list[int]):
     if _draft_graph_params is not None:
         raise ValueError("DraftGraph parameters have already been set!")
     _draft_graph_params = GraphParams(
-        {size: []
-         for size in aclgraph_capture_sizes},
-        {size: None
-         for size in aclgraph_capture_sizes},
-        {size: []
-         for size in aclgraph_capture_sizes},
-        {size: []
-         for size in aclgraph_capture_sizes},
+        {size: [] for size in aclgraph_capture_sizes},
+        {size: None for size in aclgraph_capture_sizes},
+        {size: [] for size in aclgraph_capture_sizes},
+        {size: [] for size in aclgraph_capture_sizes},
     )