[Lint]Style: Convert vllm-ascend/ to ruff format(Batch #2) (#5977)

### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/attention/attention_mask.py` |
| `vllm_ascend/attention/attention_v1.py` |
| `vllm_ascend/attention/context_parallel/attention_cp.py` |
| `vllm_ascend/attention/context_parallel/common_cp.py` |
| `vllm_ascend/attention/context_parallel/mla_cp.py` |
| `vllm_ascend/attention/utils.py` |
| `vllm_ascend/batch_invariant.py` |
| `vllm_ascend/device/device_op.py` |
| `vllm_ascend/device_allocator/camem.py` |
| `vllm_ascend/envs.py` |


- vLLM version: v0.13.0
- vLLM main:
2c24bc6996

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
SILONG ZENG
2026-01-19 08:59:46 +08:00
committed by GitHub
parent 2b6dc100b5
commit 329961b375
11 changed files with 920 additions and 1045 deletions

View File

@@ -1,18 +1,15 @@
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Any, List, Optional
from typing import Any
import torch
import torch.nn.functional as F
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group, is_v1_kv_transfer_group
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm_ascend.utils import (AscendDeviceType, get_ascend_config,
get_ascend_device_type)
from vllm_ascend.utils import AscendDeviceType, get_ascend_config, get_ascend_device_type
def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
@@ -21,6 +18,7 @@ def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
if get_ascend_device_type() == AscendDeviceType.A5:
return False
from vllm.config.compilation import CUDAGraphMode
cudagraph_mode = vllm_config.compilation_config.cudagraph_mode
if cudagraph_mode != CUDAGraphMode.FULL_DECODE_ONLY:
return False
@@ -31,8 +29,7 @@ def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
@lru_cache(maxsize=1)
def enable_cp():
prefill_config = get_current_vllm_config().parallel_config
return prefill_config.prefill_context_parallel_size > 1 \
or prefill_config.decode_context_parallel_size > 1
return prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1
@dataclass
@@ -42,13 +39,14 @@ class AscendPrefillContextParallelMetadata:
Contains index tensors and sequence lengths for PCP operations.
"""
pcp_allgather_restore_idx: torch.Tensor = None
cp_kv_recover_idx_for_chunk: torch.Tensor = None
num_actual_tokens_pcp_padded: int = 0
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
num_computed_tokens_of_pcp_dcp: list[list[list[int]]] | None = None
q_head_idx_tensor: torch.Tensor = None
@@ -85,6 +83,7 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
For many of the tensors we keep both NPU and CPU versions.
"""
# CPU tensor of sequence lengths for host-side operations.
# E.g., tensor([128, 256, 64]) for 3 requests with different seq lengths.
seq_lens_cpu: torch.Tensor = None
@@ -115,20 +114,17 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
num_input_tokens: int = 0
# Metadata for Prefill Context Parallelism (PCP) operations.
prefill_context_parallel_metadata: Optional[
AscendPrefillContextParallelMetadata] = None
prefill_context_parallel_metadata: AscendPrefillContextParallelMetadata | None = None
# TODO: Remove it when vLLM no longer uses this function.
def unpadded(self, num_actual_tokens: int,
num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
def unpadded(self, num_actual_tokens: int, num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
# This only use to eagle now. It will be use to enforce_eager in future.
return AscendCommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_actual_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_actual_reqs + 1],
query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
seq_lens=self.seq_lens[:num_actual_reqs],
seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs],
num_computed_tokens_cpu=self.
num_computed_tokens_cpu[:num_actual_reqs],
num_computed_tokens_cpu=self.num_computed_tokens_cpu[:num_actual_reqs],
num_reqs=num_actual_reqs,
num_actual_tokens=num_actual_tokens,
max_query_len=self.max_query_len,
@@ -144,14 +140,14 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
attn_state=self.attn_state,
graph_pad_size=-1, # It should be -1 when not run in fullgraph mode.
num_input_tokens=self.num_input_tokens,
prefill_context_parallel_metadata=self.
prefill_context_parallel_metadata,
max_seq_len=self.max_seq_len)
prefill_context_parallel_metadata=self.prefill_context_parallel_metadata,
max_seq_len=self.max_seq_len,
)
def filter_chunked_req_indices(
seq_len: torch.Tensor,
mask_for_non_zero_chunk: Optional[List[bool]],
mask_for_non_zero_chunk: list[bool] | None,
) -> torch.Tensor:
"""
filter the reqs which are doing real chunk_prefill.
@@ -162,14 +158,15 @@ def filter_chunked_req_indices(
Returns:
filtered_indices: the real chunked req's indices
"""
assert mask_for_non_zero_chunk is not None and len(seq_len) == len(
mask_for_non_zero_chunk)
assert mask_for_non_zero_chunk is not None and len(seq_len) == len(mask_for_non_zero_chunk)
offsets = torch.cumsum(torch.cat([torch.tensor([0]), seq_len[:-1]]), dim=0)
filtered_indices = torch.cat([
torch.arange(offsets[i], offsets[i] + seq_len[i])
for i in range(len(mask_for_non_zero_chunk))
if mask_for_non_zero_chunk[i]
])
filtered_indices = torch.cat(
[
torch.arange(offsets[i], offsets[i] + seq_len[i])
for i in range(len(mask_for_non_zero_chunk))
if mask_for_non_zero_chunk[i]
]
)
return filtered_indices
@@ -195,12 +192,9 @@ def split_decodes_and_prefills(
num_prefill_tokens: The number of tokens in the prefill requests.
"""
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
query_lens_pcp_full = long_seq_metadata.query_lens_pcp_full_cpu \
if long_seq_metadata else None
max_query_len_pcp_full = long_seq_metadata.max_query_len_pcp_full \
if long_seq_metadata else 0
max_query_len = common_attn_metadata.max_query_len \
if max_query_len_pcp_full == 0 else max_query_len_pcp_full
query_lens_pcp_full = long_seq_metadata.query_lens_pcp_full_cpu if long_seq_metadata else None
max_query_len_pcp_full = long_seq_metadata.max_query_len_pcp_full if long_seq_metadata else 0
max_query_len = common_attn_metadata.max_query_len if max_query_len_pcp_full == 0 else max_query_len_pcp_full
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc_cpu
@@ -208,8 +202,7 @@ def split_decodes_and_prefills(
if max_query_len <= decode_threshold:
return num_reqs, 0, num_tokens, 0
query_lens = (query_start_loc[1:] - query_start_loc[:-1]) \
if query_lens_pcp_full is None else query_lens_pcp_full
query_lens = (query_start_loc[1:] - query_start_loc[:-1]) if query_lens_pcp_full is None else query_lens_pcp_full
is_prefill = query_lens > decode_threshold
if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0
@@ -238,7 +231,7 @@ def wait_for_kv_layer_from_connector(layer_name: str):
def maybe_save_kv_layer_to_connector(
layer_name: str,
kv_cache_layer: List[torch.Tensor],
kv_cache_layer: list[torch.Tensor],
):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
@@ -264,8 +257,7 @@ def trans_rope_weight(weight, rope_dim):
return weight.contiguous()
nope_part = weight[..., :-rope_dim, :]
rope_part = weight[..., -rope_dim:, :]
reordered_rope_part = torch.cat(
(rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2)
reordered_rope_part = torch.cat((rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2)
return torch.cat((nope_part, reordered_rope_part), dim=-2).contiguous()
@@ -278,12 +270,9 @@ def transdata(nd_mat, block_size: tuple = (16, 16)):
nz_mat = torch.permute(
torch.reshape(
nd_mat,
(r // block_size[0], block_size[0], c // block_size[1],
block_size[1]),
(r // block_size[0], block_size[0], c // block_size[1], block_size[1]),
),
[2, 0, 1, 3],
)
nz_mat = torch.reshape(
nz_mat,
(nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3]))
nz_mat = torch.reshape(nz_mat, (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3]))
return nz_mat