From 7faa6878a664f653469e4f77e981286d086b5570 Mon Sep 17 00:00:00 2001
From: SILONG ZENG <2609716663@qq.com>
Date: Sat, 24 Jan 2026 22:10:18 +0800
Subject: [PATCH] [Lint] Style: Convert `vllm-ascend/` to ruff format (Batch #3) (#5978)

### What this PR does / why we need it?

**Scope of Changes**:

| File Path |
| :--- |
| `vllm_ascend/attention/mla_v1.py` |
| `vllm_ascend/attention/sfa_v1.py` |
| `vllm_ascend/core/recompute_scheduler.py` |
| `vllm_ascend/core/scheduler_dynamic_batch.py` |
| `vllm_ascend/distributed/device_communicators/npu_communicator.py` |
| `vllm_ascend/distributed/device_communicators/pyhccl.py` |
| `vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py` |

### Does this PR introduce _any_ user-facing change?

No. This is a formatting-only change with no functional impact.

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/2c24bc6996cb165fce92f780b388a5e39b3f4060

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
Co-authored-by: Soren

---
 pyproject.toml                                |   5 -
 .../attention/context_parallel/mla_cp.py      |   2 +-
 vllm_ascend/attention/mla_v1.py               | 885 ++++++++----------
 vllm_ascend/attention/sfa_v1.py               | 537 +++++------
 vllm_ascend/core/recompute_scheduler.py       | 239 ++---
 vllm_ascend/core/scheduler_dynamic_batch.py   | 197 ++--
 .../device_communicators/npu_communicator.py  |  51 +-
 .../device_communicators/pyhccl.py            |  63 +-
 .../device_communicators/pyhccl_wrapper.py    | 122 +--
 9 files changed, 953 insertions(+), 1148 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 3f053ec5..4b32da9e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -51,11 +51,6 @@ line-length = 120
 # Folder to be modified
 exclude = [
     "tests/**",
-    # (3)
-    "vllm_ascend/attention/*.py",
-    "vllm_ascend/core/*.py",
-    "vllm_ascend/distributed/device_communicators/**",
-    "vllm_ascend/distributed/utils.py",
     # (5)
     "vllm_ascend/distributed/kv_transfer/kv_pool/**",
     "vllm_ascend/distributed/kv_transfer/utils/**",
diff --git a/vllm_ascend/attention/context_parallel/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py
index e0b77a0f..e0cd7998 100644
--- a/vllm_ascend/attention/context_parallel/mla_cp.py
+++ b/vllm_ascend/attention/context_parallel/mla_cp.py
@@ -394,7 +394,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
         prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
         prefill_kv_no_split = kv_no_split[:num_actual_tokens]
         kv_c, k_pe = prefill_kv_no_split.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())  # type: ignore[misc]
         assert len(kv_cache) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
         kv_c_normed = kv_c_normed.view([num_actual_tokens, self.num_kv_heads, -1])
         k_pe = k_pe.unsqueeze(1)
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index d625a3bb..5b81f3ba 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar
+from typing import TYPE_CHECKING, NamedTuple, TypeVar
 
 import numpy as np
 import torch
@@ -10,35 +10,42 @@ from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.utils.math_utils import cdiv, round_down
-from vllm.v1.attention.backend import (  # type: ignore
-    AttentionBackend, AttentionCGSupport, MLAAttentionImpl)
+from 
vllm.v1.attention.backend import AttentionBackend, AttentionCGSupport, MLAAttentionImpl # type: ignore from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder from vllm.v1.attention.backends.utils import PAD_SLOT_ID # type: ignore from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec -from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.attention.context_parallel.common_cp import ( - AscendPCPMetadata, CPChunkedContextMetadata) +from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata, CPChunkedContextMetadata from vllm_ascend.attention.utils import ( - AscendCommonAttentionMetadata, ascend_chunked_prefill_workspace_size, - enable_cp, maybe_save_kv_layer_to_connector, split_decodes_and_prefills, - trans_rope_weight, transdata, wait_for_kv_layer_from_connector, - enabling_malpo) + AscendCommonAttentionMetadata, + ascend_chunked_prefill_workspace_size, + enable_cp, + enabling_malpo, + maybe_save_kv_layer_to_connector, + split_decodes_and_prefills, + trans_rope_weight, + transdata, + wait_for_kv_layer_from_connector, +) from vllm_ascend.compilation.acl_graph import ( - get_draft_graph_params, get_graph_params, - update_draft_graph_params_workspaces, update_graph_params_workspaces) + get_draft_graph_params, + get_graph_params, + update_draft_graph_params_workspaces, + update_graph_params_workspaces, +) from vllm_ascend.ops.layer_shard_linear import ( - is_hidden_layer, post_process_after_loading_for_shard_weight_series, + is_hidden_layer, + post_process_after_loading_for_shard_weight_series, reach_layer_for_shard_weight_series, - register_all_layers_to_shard_weight_series) + register_all_layers_to_shard_weight_series, +) from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.quantization.methods import AscendW8A8LinearMethod -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, maybe_trans_nz, - weak_ref_tensors) +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, maybe_trans_nz, weak_ref_tensors from vllm_ascend.worker.npu_input_batch import NPUInputBatch if TYPE_CHECKING: @@ -65,21 +72,20 @@ class AscendMLABackend(AttentionBackend): @staticmethod def get_builder_cls(): if enable_cp(): - from vllm_ascend.attention.context_parallel.mla_cp import \ - AscendMlaCPMetadataBuilder + from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPMetadataBuilder + return AscendMlaCPMetadataBuilder return AscendMLAMetadataBuilder @staticmethod - def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, - head_size: int) -> tuple[int, ...]: + def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, head_size: int) -> tuple[int, ...]: return num_blocks, block_size, num_kv_heads, head_size @staticmethod - def get_impl_cls() -> Type["MLAAttentionImpl"]: + def get_impl_cls() -> type["MLAAttentionImpl"]: if enable_cp(): - from vllm_ascend.attention.context_parallel.mla_cp import \ - AscendMlaCPImpl + from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl + return AscendMlaCPImpl return AscendMLAImpl @@ -91,6 +97,7 @@ class ChunkedContextMetadata: Manages sequence boundaries and workspace for chunked prefill processing. 
""" + cu_seq_lens: torch.Tensor starts: torch.Tensor seq_tot: list[int] @@ -102,7 +109,8 @@ class ChunkedContextMetadata: @dataclass class AscendMLAPrefillMetadata: - """ Prefill Specific Metadata for Ascend""" + """Prefill Specific Metadata for Ascend""" + attn_mask: torch.Tensor query_lens: torch.Tensor seq_lens: list[int] @@ -112,16 +120,16 @@ class AscendMLAPrefillMetadata: block_table: torch.Tensor max_query_len: int max_seq_lens: int - chunked_context: Optional[ChunkedContextMetadata - | CPChunkedContextMetadata] = None + chunked_context: ChunkedContextMetadata | CPChunkedContextMetadata | None = None sin: torch.Tensor = None cos: torch.Tensor = None - pcp_metadata: Optional[AscendPCPMetadata] = None + pcp_metadata: AscendPCPMetadata | None = None @dataclass class AscendMLADecodeMetadata: - """ Decode-specific metadata for Ascend MLA attention.""" + """Decode-specific metadata for Ascend MLA attention.""" + # Input positions for rotary embeddings since for MLA the rotary # position embeddings are applied inside the attention backend input_positions: torch.Tensor @@ -129,8 +137,8 @@ class AscendMLADecodeMetadata: seq_lens: torch.Tensor max_seq_lens: int seq_lens_list: list[int] - actual_seq_lengths_q: Optional[list[int]] = None - attn_mask: Optional[torch.Tensor] = None + actual_seq_lengths_q: list[int] | None = None + attn_mask: torch.Tensor | None = None sin: torch.Tensor = None cos: torch.Tensor = None cp_seq_len: torch.Tensor = None @@ -142,6 +150,7 @@ class AscendMLAMetadata: NOTE: Please read the comment at the top of the file before trying to understand this class """ + # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| @@ -166,15 +175,15 @@ class AscendMLAMetadata: # For logging. num_input_tokens: int = 0 # Number of tokens including padding. 
- query_lens: Optional[list[int]] = None + query_lens: list[int] | None = None # The dimension of the attention heads - head_dim: Optional[int] = None + head_dim: int | None = None attn_mask: torch.Tensor = None # chunked prefill by default if no attn_states passed attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill - decode: Optional[AscendMLADecodeMetadata] = None - prefill: Optional[AscendMLAPrefillMetadata] = None + decode: AscendMLADecodeMetadata | None = None + prefill: AscendMLAPrefillMetadata | None = None reshape_cache_event: torch.npu.Event = None def __post_init__(self): @@ -206,14 +215,17 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): supports_dcp_with_varlen: bool = False, ): super().__init__( - kv_cache_spec, layer_names, vllm_config, device, + kv_cache_spec, + layer_names, + vllm_config, + device, metadata_cls if metadata_cls is not None else AscendMLAMetadata, - supports_dcp_with_varlen) + supports_dcp_with_varlen, + ) scheduler_config = vllm_config.scheduler_config self.block_size = vllm_config.cache_config.block_size - self.max_blocks = (vllm_config.model_config.max_model_len + - self.block_size - 1) // self.block_size + self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill self.speculative_config = vllm_config.speculative_config @@ -221,12 +233,13 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): if self.speculative_config: spec_token_num = self.speculative_config.num_speculative_tokens self.decode_threshold += spec_token_num - assert self.decode_threshold <= 16, f"decode_threshold exceeded \ + assert self.decode_threshold <= 16, ( + f"decode_threshold exceeded \ npu_fused_infer_attention_score TND layout's limit of 16, \ got {self.decode_threshold}" + ) self.reorder_batch_threshold = self.decode_threshold - self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim self.cos_cache = None self.sin_cache = None @@ -240,7 +253,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): self.num_decode_tokens = 0 self.num_prefill_tokens = 0 self.context_lens_cpu: torch.Tensor = None - self.num_actual_tokens: Optional[int] = None + self.num_actual_tokens: int | None = None self.block_table: torch.Tensor = None self.slot_mapping: torch.Tensor = None self.graph_pad_size = 0 @@ -249,8 +262,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): self.attn_mask_builder = AttentionMaskBuilder(self.device) @staticmethod - def determine_chunked_prefill_workspace_size( - vllm_config: VllmConfig) -> int: + def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int: return ascend_chunked_prefill_workspace_size(vllm_config) @classmethod @@ -263,8 +275,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): # @override omitted only because of mypy limitation due to type variable. return AttentionCGSupport.UNIFORM_BATCH - def reorder_batch(self, input_batch: "NPUInputBatch", - scheduler_output: "SchedulerOutput") -> bool: + def reorder_batch(self, input_batch: "NPUInputBatch", scheduler_output: "SchedulerOutput") -> bool: # We now want to reorder the batch so that the "decode" requests are at # the front and the "prefill" requests are at the using the least amount # swaps possible. 
(NOTE for now we loosely use "decode" to mean requests @@ -300,8 +311,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): # If the decode is at the "back" of the batch, i, we can swap it # with the prefill closest to the front of the batch if decodes[num_decodes - i] >= num_decodes: - input_batch.swap_states(prefills[first_prefill], - decodes[num_decodes - i]) + input_batch.swap_states(prefills[first_prefill], decodes[num_decodes - i]) first_prefill += 1 modified_batch = True else: @@ -312,9 +322,9 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): # better way of doing this return modified_batch - def pad_actual_seq_len_q_mtp_enable_pad(self, num_reqs_pad_size, num_reqs, - actual_seq_lengths_q, - common_attn_metadata): + def pad_actual_seq_len_q_mtp_enable_pad( + self, num_reqs_pad_size, num_reqs, actual_seq_lengths_q, common_attn_metadata + ): """ Pads actual_seq_lengths_q evenly to not exceed 16 tokens per request in order to meet the requirement of npu_fused_infer_attention_score. @@ -329,35 +339,35 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): However, mtp torchair + PD scenario, the actual_seq_lengths_q may be [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] before padding, since the first decode request only has 1 token. - In order to meet the requirement of npu_fused_infer_attention_score, we need to pad actual_seq_lengths_q evenly to not exceed 16 tokens per request. + In order to meet the requirement of npu_fused_infer_attention_score, we need to pad actual_seq_lengths_q + evenly to not exceed 16 tokens per request. after padding actual_seq_lengths_q should be similar to [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,36] """ FIA_SEQ_LEN_LIMIT = 16 - need_padding = num_reqs_pad_size != 0 and \ - len(common_attn_metadata.actual_seq_lengths_q) > num_reqs and \ - common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[ - -1] > FIA_SEQ_LEN_LIMIT + need_padding = ( + num_reqs_pad_size != 0 + and len(common_attn_metadata.actual_seq_lengths_q) > num_reqs + and common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > FIA_SEQ_LEN_LIMIT + ) if need_padding: - padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[ - num_reqs:num_reqs + num_reqs_pad_size] + padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[num_reqs : num_reqs + num_reqs_pad_size] start_val = actual_seq_lengths_q[-1] end_val = padding_seq_len_q[-1] num_step = len(padding_seq_len_q) - interpolated = np.round( - np.linspace(start_val, end_val, - num_step + 1)[1:]).astype(int).tolist() + interpolated = np.round(np.linspace(start_val, end_val, num_step + 1)[1:]).astype(int).tolist() assert interpolated[-1] == end_val assert len(interpolated) == len(padding_seq_len_q) actual_seq_lengths_q = actual_seq_lengths_q + interpolated else: - actual_seq_lengths_q = actual_seq_lengths_q + common_attn_metadata.actual_seq_lengths_q[ - num_reqs:num_reqs + num_reqs_pad_size] + actual_seq_lengths_q = ( + actual_seq_lengths_q + + common_attn_metadata.actual_seq_lengths_q[num_reqs : num_reqs + num_reqs_pad_size] + ) return actual_seq_lengths_q - def pad_actual_seq_len_q_mtp_disable_pad(self, num_reqs_pad_size, num_reqs, - actual_seq_lengths_q): + def pad_actual_seq_len_q_mtp_disable_pad(self, num_reqs_pad_size, num_reqs, actual_seq_lengths_q): """ Only use for acl full graph mode. 
Pad the last element of the actual_seq_lengths_q equal to the TND(T) and @@ -373,9 +383,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): start_val = actual_seq_lengths_q[-1] end_val = num_reqs + num_reqs_pad_size num_step = num_reqs_pad_size - interpolated = np.round( - np.linspace(start_val, end_val, - num_step + 1)[1:]).astype(int).tolist() + interpolated = np.round(np.linspace(start_val, end_val, num_step + 1)[1:]).astype(int).tolist() assert interpolated[-1] == end_val assert len(interpolated) == num_reqs_pad_size actual_seq_lengths_q = actual_seq_lengths_q + interpolated @@ -397,35 +405,31 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - self.num_decodes, self.num_prefills, self.num_decode_tokens, self.num_prefill_tokens = \ + self.num_decodes, self.num_prefills, self.num_decode_tokens, self.num_prefill_tokens = ( split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) + ) self.set_num_actual_tokens(common_attn_metadata) assert self.num_decodes + self.num_prefills == num_reqs assert self.num_decode_tokens + self.num_prefill_tokens == common_attn_metadata.num_actual_tokens # NOTE: Currently, MTP-fullgraph is incompatibility pcp - self.slot_mapping = common_attn_metadata.slot_mapping[:self. - num_actual_tokens] + self.slot_mapping = common_attn_metadata.slot_mapping[: self.num_actual_tokens] query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] self.query_lens = query_seq_lens_cpu[:num_reqs] self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] self.graph_pad_size = common_attn_metadata.graph_pad_size - block_table_size = self.get_block_table_size( - common_attn_metadata, BUILD_METADATA_STEP_PREFILL) - self.block_table = common_attn_metadata.block_table_tensor[: - block_table_size] + block_table_size = self.get_block_table_size(common_attn_metadata, BUILD_METADATA_STEP_PREFILL) + self.block_table = common_attn_metadata.block_table_tensor[:block_table_size] prefill_metadata = None if self.num_prefills > 0: - prefill_metadata = self.build_prefill_metadata( - common_prefix_len, common_attn_metadata) + prefill_metadata = self.build_prefill_metadata(common_prefix_len, common_attn_metadata) decode_metadata = None if self.num_decodes > 0: - decode_metadata = self.build_decode_metadata( - common_prefix_len, common_attn_metadata) + decode_metadata = self.build_decode_metadata(common_prefix_len, common_attn_metadata) return self.metadata_cls( # type: ignore num_actual_tokens_pcp_padded=self.num_actual_tokens, num_input_tokens=common_attn_metadata.num_input_tokens, @@ -436,8 +440,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): num_decodes=self.num_decodes, num_decode_tokens=self.num_decode_tokens, num_prefills=self.num_prefills, - attn_mask=self.attn_mask_builder.get_final_mla_mask( - self.model_config), + attn_mask=self.attn_mask_builder.get_final_mla_mask(self.model_config), attn_state=common_attn_metadata.attn_state, prefill=prefill_metadata, decode=decode_metadata, @@ -455,40 +458,30 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): return None num_reqs = common_attn_metadata.num_reqs - num_computed_tokens_cpu = (self.seq_lens - self.query_lens) + num_computed_tokens_cpu = self.seq_lens - self.query_lens reqs_start = self.num_decodes # prefill_start self.context_lens_cpu = 
num_computed_tokens_cpu[reqs_start:num_reqs] max_context_len_cpu = self.context_lens_cpu.max().item() if not max_context_len_cpu > 0: return None - num_prefills_with_context_cpu = (self.context_lens_cpu - > 0).sum().item() - self.max_context_chunk = (self.chunked_prefill_workspace_size // - num_prefills_with_context_cpu) - self.max_context_chunk = round_down(self.max_context_chunk, - self.block_size) + num_prefills_with_context_cpu = (self.context_lens_cpu > 0).sum().item() + self.max_context_chunk = self.chunked_prefill_workspace_size // num_prefills_with_context_cpu + self.max_context_chunk = round_down(self.max_context_chunk, self.block_size) assert self.max_context_chunk > 0 self.num_chunks = cdiv(max_context_len_cpu, self.max_context_chunk) - chunk_starts = torch.arange(self.num_chunks, dtype=torch.int32) \ - .unsqueeze(1).expand(-1, self.num_prefills) * self.max_context_chunk - chunk_ends = torch.min(self.context_lens_cpu.unsqueeze(0), - chunk_starts + self.max_context_chunk) + chunk_starts = ( + torch.arange(self.num_chunks, dtype=torch.int32).unsqueeze(1).expand(-1, self.num_prefills) + * self.max_context_chunk + ) + chunk_ends = torch.min(self.context_lens_cpu.unsqueeze(0), chunk_starts + self.max_context_chunk) self.chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) - self.cu_seq_lens_cpu = torch.zeros(self.num_chunks, - self.num_prefills + 1, - dtype=torch.int32, - pin_memory=True) - torch.cumsum(self.chunk_seq_lens, - dim=1, - out=self.cu_seq_lens_cpu[:, 1:], - dtype=torch.int32) + self.cu_seq_lens_cpu = torch.zeros(self.num_chunks, self.num_prefills + 1, dtype=torch.int32, pin_memory=True) + torch.cumsum(self.chunk_seq_lens, dim=1, out=self.cu_seq_lens_cpu[:, 1:], dtype=torch.int32) return ChunkedContextMetadata( - cu_seq_lens=self.cu_seq_lens_cpu.pin_memory().to( - self.device, non_blocking=True), - starts=chunk_starts.pin_memory().to(self.device, - non_blocking=True), + cu_seq_lens=self.cu_seq_lens_cpu.pin_memory().to(self.device, non_blocking=True), + starts=chunk_starts.pin_memory().to(self.device, non_blocking=True), seq_tot=self.chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=self.chunk_seq_lens.max(dim=1).values.tolist(), chunk_seq_lens=self.chunk_seq_lens, @@ -496,13 +489,14 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): workspace=self.chunked_prefill_workspace, ) - def get_block_table_size( - self, common_attn_metadata: AscendCommonAttentionMetadata, - build_metadata_step: int): + def get_block_table_size(self, common_attn_metadata: AscendCommonAttentionMetadata, build_metadata_step: int): if build_metadata_step == BUILD_METADATA_STEP_PREFILL: # If graph_pad_size > -1, mean is running in fullgraph mode. # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1. - if self.graph_pad_size > common_attn_metadata.num_reqs and self.speculative_config.disable_padded_drafter_batch: + if ( + self.graph_pad_size > common_attn_metadata.num_reqs + and self.speculative_config.disable_padded_drafter_batch + ): return self.graph_pad_size return common_attn_metadata.num_reqs return self.num_decodes @@ -515,24 +509,19 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): query_start_loc = common_attn_metadata.query_start_loc # NOTE: Currently, MTP-fullgraph is incompatibility pcp - input_positions = common_attn_metadata.positions[:self. 
- num_actual_tokens].long( - ) + input_positions = common_attn_metadata.positions[: self.num_actual_tokens].long() - chunked_context_metadata = self.build_chunked_metadata( - common_prefix_len, common_attn_metadata) + chunked_context_metadata = self.build_chunked_metadata(common_prefix_len, common_attn_metadata) reqs_start = self.num_decodes # prefill_start tokens_start = self.num_decode_tokens max_query_len = self.query_lens[reqs_start:].max().item() max_seq_lens = self.seq_lens[reqs_start:].max().item() - prefill_query_start_loc = query_start_loc[ - reqs_start:] - query_start_loc[reqs_start] + prefill_query_start_loc = query_start_loc[reqs_start:] - query_start_loc[reqs_start] prefill_input_positions = input_positions[tokens_start:] cos, sin = get_cos_and_sin_mla(prefill_input_positions) return AscendMLAPrefillMetadata( - attn_mask=self.attn_mask_builder.get_final_mla_mask( - self.model_config), + attn_mask=self.attn_mask_builder.get_final_mla_mask(self.model_config), query_lens=self.query_lens[reqs_start:].to(torch.int32), seq_lens=self.seq_lens, context_lens=self.seq_lens[reqs_start:], @@ -554,26 +543,21 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): num_reqs = common_attn_metadata.num_reqs query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - input_positions = common_attn_metadata.positions[:self. - num_actual_tokens].long( - ) + input_positions = common_attn_metadata.positions[: self.num_actual_tokens].long() # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario - actual_seq_lengths_q = query_start_loc_cpu[1:self.num_decodes + - 1].tolist() - max_seq_lens = self.seq_lens[:self.num_decodes].max().item() - self.seq_lens = self.seq_lens[:self.num_decodes] - input_positions = input_positions[:self.num_decode_tokens] + actual_seq_lengths_q = query_start_loc_cpu[1 : self.num_decodes + 1].tolist() + max_seq_lens = self.seq_lens[: self.num_decodes].max().item() + self.seq_lens = self.seq_lens[: self.num_decodes] + input_positions = input_positions[: self.num_decode_tokens] - block_table_size = self.get_block_table_size( - common_attn_metadata, BUILD_METADATA_STEP_DECODE) + block_table_size = self.get_block_table_size(common_attn_metadata, BUILD_METADATA_STEP_DECODE) self.block_table = self.block_table[:block_table_size] # NOTE: Currently, MTP-fullgraph is incompatibility pcp # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1. - if self.graph_pad_size > self.num_decodes and \ - self.speculative_config.disable_padded_drafter_batch: - self.block_table = self.block_table[:self.graph_pad_size, ...] + if self.graph_pad_size > self.num_decodes and self.speculative_config.disable_padded_drafter_batch: + self.block_table = self.block_table[: self.graph_pad_size, ...] 
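# Illustrative aside, not part of the patch: the pad_actual_seq_len_q_mtp_*
# helpers above (called just below) spread padding evenly with a rounded
# linspace. A minimal standalone sketch with made-up lengths:
import numpy as np

def interpolate_pad(actual_seq_lengths_q: list[int], end_val: int, num_step: int) -> list[int]:
    # Interpolate from the last real cumulative query length up to the padded
    # total, mirroring the np.linspace logic in the helpers above.
    start_val = actual_seq_lengths_q[-1]
    interpolated = np.round(np.linspace(start_val, end_val, num_step + 1)[1:]).astype(int).tolist()
    assert interpolated[-1] == end_val
    return actual_seq_lengths_q + interpolated

print(interpolate_pad([1, 2, 3, 4], end_val=16, num_step=3))  # [1, 2, 3, 4, 8, 12, 16]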
seq_lens_list = self.seq_lens.tolist() cp_seq_len = None @@ -582,49 +566,41 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): if self.speculative_config.disable_padded_drafter_batch: num_reqs_pad_size = self.graph_pad_size - num_reqs actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad( - num_reqs_pad_size, num_reqs, actual_seq_lengths_q) - seq_lens_list = seq_lens_list + [0] * (self.graph_pad_size - - self.num_decodes) - num_block_pad_size = self.graph_pad_size - self.block_table.shape[ - 0] + num_reqs_pad_size, num_reqs, actual_seq_lengths_q + ) + seq_lens_list = seq_lens_list + [0] * (self.graph_pad_size - self.num_decodes) + num_block_pad_size = self.graph_pad_size - self.block_table.shape[0] if num_block_pad_size > 0: block_table_padding = torch.zeros( - (num_block_pad_size, ) + self.block_table.shape[1:], + (num_block_pad_size,) + self.block_table.shape[1:], dtype=self.block_table.dtype, - device=self.block_table.device) - self.block_table = torch.cat( - [self.block_table, block_table_padding], dim=0) + device=self.block_table.device, + ) + self.block_table = torch.cat([self.block_table, block_table_padding], dim=0) else: num_token_pad_size = self.graph_pad_size - self.num_decode_tokens - num_reqs_pad_size = ( - self.graph_pad_size // - common_attn_metadata.decode_token_per_req - num_reqs) + num_reqs_pad_size = self.graph_pad_size // common_attn_metadata.decode_token_per_req - num_reqs num_block_table_pad_size = ( - self.graph_pad_size // - common_attn_metadata.decode_token_per_req - - self.num_decodes) - seq_lens_list = self.seq_lens.tolist() + [0 - ] * num_reqs_pad_size - slot_padding = torch.full((num_token_pad_size, ), - PAD_SLOT_ID, - dtype=self.slot_mapping.dtype, - device=self.slot_mapping.device) - self.slot_mapping = torch.cat( - [self.slot_mapping, slot_padding]) + self.graph_pad_size // common_attn_metadata.decode_token_per_req - self.num_decodes + ) + seq_lens_list = self.seq_lens.tolist() + [0] * num_reqs_pad_size + slot_padding = torch.full( + (num_token_pad_size,), PAD_SLOT_ID, dtype=self.slot_mapping.dtype, device=self.slot_mapping.device + ) + self.slot_mapping = torch.cat([self.slot_mapping, slot_padding]) block_table_padding = torch.zeros( - (num_block_table_pad_size, ) + self.block_table.shape[1:], + (num_block_table_pad_size,) + self.block_table.shape[1:], dtype=self.block_table.dtype, - device=self.block_table.device) - self.block_table = torch.cat( - [self.block_table, block_table_padding], dim=0) - position_padding = torch.zeros(num_token_pad_size, - dtype=input_positions.dtype, - device=input_positions.device) - input_positions = torch.cat( - [input_positions, position_padding]) + device=self.block_table.device, + ) + self.block_table = torch.cat([self.block_table, block_table_padding], dim=0) + position_padding = torch.zeros( + num_token_pad_size, dtype=input_positions.dtype, device=input_positions.device + ) + input_positions = torch.cat([input_positions, position_padding]) actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_enable_pad( - num_reqs_pad_size, num_reqs, actual_seq_lengths_q, - common_attn_metadata) + num_reqs_pad_size, num_reqs, actual_seq_lengths_q, common_attn_metadata + ) cos, sin = get_cos_and_sin_mla(input_positions, use_cache=True) decode_metadata = AscendMLADecodeMetadata( @@ -635,9 +611,10 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): max_seq_lens=max_seq_lens, attn_mask=self.attn_mask_builder.get_splitfuse_attn_mask(), 
actual_seq_lengths_q=actual_seq_lengths_q, - sin=sin[:self.num_decode_tokens, ...], - cos=cos[:self.num_decode_tokens, ...], - cp_seq_len=cp_seq_len) + sin=sin[: self.num_decode_tokens, ...], + cos=cos[: self.num_decode_tokens, ...], + cp_seq_len=cp_seq_len, + ) return decode_metadata def build_for_graph_capture( @@ -645,10 +622,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): common_attn_metadata: AscendCommonAttentionMetadata, attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, ): - if attn_state in { - AscendAttentionState.DecodeOnly, - AscendAttentionState.SpecDecoding - }: + if attn_state in {AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding}: attn_metadata = self.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, @@ -663,19 +637,19 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): class DecodeMLAPreprocessResult(NamedTuple): - ql_nope: Optional[torch.Tensor] = None - q_pe: Optional[torch.Tensor] = None - k_nope: Optional[torch.Tensor] = None - k_pe: Optional[torch.Tensor] = None - decode_q_wo_k_up: Optional[torch.Tensor] = None + ql_nope: torch.Tensor | None = None + q_pe: torch.Tensor | None = None + k_nope: torch.Tensor | None = None + k_pe: torch.Tensor | None = None + decode_q_wo_k_up: torch.Tensor | None = None class PrefillMLAPreprocessResult(NamedTuple): - q_nope: Optional[torch.Tensor] = None - q_pe: Optional[torch.Tensor] = None - k_nope: Optional[torch.Tensor] = None - k_pe: Optional[torch.Tensor] = None - value: Optional[torch.Tensor] = None + q_nope: torch.Tensor | None = None + q_pe: torch.Tensor | None = None + k_nope: torch.Tensor | None = None + k_pe: torch.Tensor | None = None + value: torch.Tensor | None = None class AscendMLAImpl(MLAAttentionImpl): @@ -690,12 +664,12 @@ class AscendMLAImpl(MLAAttentionImpl): head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float], + logits_soft_cap: float | None, attn_type: str, - kv_sharing_target_layer_name: Optional[str], + kv_sharing_target_layer_name: str | None, **kwargs, ): self.vllm_config = get_current_vllm_config() @@ -706,22 +680,21 @@ class AscendMLAImpl(MLAAttentionImpl): self.kv_cache_dtype = kv_cache_dtype # MLA Args - self.q_lora_rank = kwargs['q_lora_rank'] - self.kv_lora_rank = kwargs['kv_lora_rank'] - self.qk_nope_head_dim = kwargs['qk_nope_head_dim'] - self.qk_rope_head_dim = kwargs['qk_rope_head_dim'] - self.qk_head_dim = kwargs['qk_head_dim'] - self.v_head_dim = kwargs['v_head_dim'] - self.rotary_emb = kwargs['rotary_emb'] - self.fused_qkv_a_proj = kwargs.get('fused_qkv_a_proj', None) - self.q_proj = kwargs['q_proj'] if self.q_lora_rank is None else kwargs[ - 'q_b_proj'] - self.kv_b_proj = kwargs['kv_b_proj'] - self.o_proj = kwargs['o_proj'] + self.q_lora_rank = kwargs["q_lora_rank"] + self.kv_lora_rank = kwargs["kv_lora_rank"] + self.qk_nope_head_dim = kwargs["qk_nope_head_dim"] + self.qk_rope_head_dim = kwargs["qk_rope_head_dim"] + self.qk_head_dim = kwargs["qk_head_dim"] + self.v_head_dim = kwargs["v_head_dim"] + self.rotary_emb = kwargs["rotary_emb"] + self.fused_qkv_a_proj = kwargs.get("fused_qkv_a_proj") + self.q_proj = kwargs["q_proj"] if self.q_lora_rank is None else kwargs["q_b_proj"] + self.kv_b_proj = kwargs["kv_b_proj"] + self.o_proj = kwargs["o_proj"] self.vllm_config = get_current_vllm_config() - 
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) - self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) - self.q_a_layernorm = kwargs.get('q_a_layernorm', None) + self.kv_a_proj_with_mqa = kwargs.get("kv_a_proj_with_mqa") + self.kv_a_layernorm = kwargs.get("kv_a_layernorm") + self.q_a_layernorm = kwargs.get("q_a_layernorm") self.num_queries_per_kv = self.num_heads // self.num_kv_heads ascend_config = get_ascend_config() @@ -734,9 +707,11 @@ class AscendMLAImpl(MLAAttentionImpl): self.speculative_config = self.vllm_config.speculative_config self.enable_mlapo = enabling_malpo(self.vllm_config) - self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer + self.is_kv_producer = ( + self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer + ) self.layer_sharding_kwargs = [] - for layer_name in (get_ascend_config().layer_sharding or []): + for layer_name in get_ascend_config().layer_sharding or []: if layer_name in kwargs: self.layer_sharding_kwargs.append(kwargs[layer_name]) else: @@ -838,9 +813,11 @@ class AscendMLAImpl(MLAAttentionImpl): # Return `ql_nope`, `q_pe` def _q_proj_and_k_up_proj(self, x): - q_nope, q_pe = self.q_proj(x)[0] \ - .view(-1, self.num_heads, self.qk_head_dim) \ + q_nope, q_pe = ( + self.q_proj(x)[0] + .view(-1, self.num_heads, self.qk_head_dim) .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + ) # Convert from (B, N, P) to (N, B, P) q_nope = q_nope.transpose(0, 1) @@ -853,24 +830,24 @@ class AscendMLAImpl(MLAAttentionImpl): # NOTE: We currently do not support quant kv_b_proj. assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod) # NOTE: Weight will be reshaped next, we need to revert and transpose it. - kv_b_proj_weight = torch_npu.npu_format_cast( - self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T + kv_b_proj_weight = torch_npu.npu_format_cast(self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}") + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + ), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}" + ) kv_b_proj_weight = kv_b_proj_weight.view( self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim + self.v_head_dim, ) - W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + W_UK, W_UV = kv_b_proj_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) # Convert from (L, N, V) to (N, L, V) self.W_UV = W_UV.transpose(0, 1).contiguous() @@ -884,97 +861,72 @@ class AscendMLAImpl(MLAAttentionImpl): # Currently mlapo only supports W8A8 quantization in MLA scenario # TODO(whx): modify this limitation when mlapo supports floating point if self.fused_qkv_a_proj is None or not isinstance( - getattr(self.fused_qkv_a_proj.quant_method, 'quant_method', - None), AscendW8A8LinearMethod): + getattr(self.fused_qkv_a_proj.quant_method, "quant_method", None), AscendW8A8LinearMethod + ): self.enable_mlapo = False logger.warning_once( "Currently mlapo only supports W8A8 quantization in MLA scenario." 
"Some layers in your model are not quantized with W8A8," - "thus mlapo is disabled for these layers.") + "thus mlapo is disabled for these layers." + ) if self.enable_mlapo: self._process_weights_for_fused_mlapo(act_dtype) else: # if mlapo, W_UK_T can't trans nz self.W_UK_T = maybe_trans_nz(self.W_UK_T) - for layer in (self.layer_sharding_kwargs or []): + for layer in self.layer_sharding_kwargs or []: if is_hidden_layer(layer): post_process_after_loading_for_shard_weight_series(layer) def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype): - kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[ - ..., self.q_lora_rank:].contiguous() - q_a_proj_wt = self.fused_qkv_a_proj.weight.data[ - ..., :self.q_lora_rank].contiguous() + kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous() # type: ignore[union-attr] + q_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous() # type: ignore[union-attr] kv_a_proj_wt = kv_a_proj_wt.t().contiguous() kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim) kv_a_proj_wt = kv_a_proj_wt.t().contiguous() wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1) wd_qkv = wd_qkv.t().contiguous() - wd_qkv = transdata(wd_qkv, - block_size=(16, 32)).unsqueeze(0).contiguous() + wd_qkv = transdata(wd_qkv, block_size=(16, 32)).unsqueeze(0).contiguous() self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29) - kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[ - self.q_lora_rank:].contiguous() - q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[:self. - q_lora_rank].contiguous( - ) - kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape( - self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() - kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl, - self.qk_rope_head_dim) - kv_a_proj_deq_scl = kv_a_proj_deq_scl.view( - self.kv_lora_rank + self.qk_rope_head_dim).contiguous() - self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl), - dim=-1).contiguous() + kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[self.q_lora_rank :].contiguous() # type: ignore[union-attr] + q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[: self.q_lora_rank].contiguous() # type: ignore[union-attr] + kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() + kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl, self.qk_rope_head_dim) + kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(self.kv_lora_rank + self.qk_rope_head_dim).contiguous() + self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl), dim=-1).contiguous() - kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[ - self.q_lora_rank:].contiguous() - q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[:self. 
- q_lora_rank].contiguous( - ) - kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape( - self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() - kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias, - self.qk_rope_head_dim) - kv_a_proj_qt_bias = kv_a_proj_qt_bias.view( - self.kv_lora_rank + self.qk_rope_head_dim).contiguous() - self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias), - dim=-1).contiguous() + kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[self.q_lora_rank :].contiguous() # type: ignore[union-attr] + q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[: self.q_lora_rank].contiguous() # type: ignore[union-attr] + kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() + kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias, self.qk_rope_head_dim) + kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(self.kv_lora_rank + self.qk_rope_head_dim).contiguous() + self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias), dim=-1).contiguous() wu_q = self.q_proj.weight.data - wu_q = wu_q.t().reshape(self.num_heads, - self.qk_nope_head_dim + self.qk_rope_head_dim, - -1) + wu_q = wu_q.t().reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim) - wu_q = wu_q.reshape( - self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), - -1) + wu_q = wu_q.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), -1) wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous() self.wu_q = torch_npu.npu_format_cast(wu_q, 29) qb_deq_scl = self.q_proj.deq_scale.data - qb_deq_scl = qb_deq_scl.reshape( - self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) + qb_deq_scl = qb_deq_scl.reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim) - self.qb_deq_scl = qb_deq_scl.reshape( - self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) + self.qb_deq_scl = qb_deq_scl.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) qb_qt_bias = self.q_proj.quant_bias.data - qb_qt_bias = qb_qt_bias.reshape( - self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) + qb_qt_bias = qb_qt_bias.reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim) - self.qb_qt_bias = qb_qt_bias.reshape( - self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) + self.qb_qt_bias = qb_qt_bias.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) device = self.q_proj.weight.device - self.gamma1 = self.q_a_layernorm.weight.data - self.beta1 = torch.zeros_like(self.gamma1) if ( - _bias := self.q_a_layernorm.bias) is None else _bias.data - self.gamma2 = self.kv_a_layernorm.weight.data - self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data - self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data + self.gamma1 = self.q_a_layernorm.weight.data # type: ignore[union-attr] + self.beta1 = torch.zeros_like(self.gamma1) if (_bias := self.q_a_layernorm.bias) is None else _bias.data # type: ignore[union-attr] + self.gamma2 = self.kv_a_layernorm.weight.data # type: ignore[union-attr] + self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data # type: ignore[union-attr] + self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data # type: ignore[union-attr] self.quant_scale1 = self.q_proj.input_scale.data 
self.quant_offset1 = self.q_proj.input_offset.data self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device) @@ -983,19 +935,20 @@ class AscendMLAImpl(MLAAttentionImpl): # On KV consumers (decode-only) MLAPO uses the transformed weights built above; # the original fused_qkv_a_proj/q_proj weights and quant params are no longer # referenced, so drop them to save memory. - if self.vllm_config.kv_transfer_config is not None and \ - self.vllm_config.kv_transfer_config.is_kv_consumer and \ - self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS: - self.fused_qkv_a_proj.weight = None - self.fused_qkv_a_proj.deq_scale = None - self.fused_qkv_a_proj.quant_bias = None + if ( + self.vllm_config.kv_transfer_config is not None + and self.vllm_config.kv_transfer_config.is_kv_consumer + and self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS + ): + self.fused_qkv_a_proj.weight = None # type: ignore[union-attr] + self.fused_qkv_a_proj.deq_scale = None # type: ignore[union-attr] + self.fused_qkv_a_proj.quant_bias = None # type: ignore[union-attr] self.q_proj.weight = None self.q_proj.deq_scale = None self.q_proj.quant_bias = None torch.npu.empty_cache() - def get_context_seq_len_npu(self, index: int, - attn_metadata: AscendMLAMetadata): + def get_context_seq_len_npu(self, index: int, attn_metadata: AscendMLAMetadata): prefill_metadata = attn_metadata.prefill assert prefill_metadata is not None assert prefill_metadata.chunked_context is not None @@ -1018,7 +971,7 @@ class AscendMLAImpl(MLAAttentionImpl): self, q_nope: torch.Tensor, q_pe: torch.Tensor, - kv_c_and_k_pe_cache: Tuple[torch.Tensor], + kv_c_and_k_pe_cache: tuple[torch.Tensor], rope_dim: int, attn_metadata: AscendMLAMetadata, prefix_output: torch.Tensor, @@ -1031,8 +984,7 @@ class AscendMLAImpl(MLAAttentionImpl): iters = len(prefill_metadata.chunked_context.seq_tot) - current_seq_len = torch.tensor(prefill_metadata.query_lens, - dtype=torch.int32) + current_seq_len = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32) cache_kv_c = kv_c_and_k_pe_cache[0] cache_k_pe = kv_c_and_k_pe_cache[1] num_heads = cache_k_pe.size(2) @@ -1040,21 +992,11 @@ class AscendMLAImpl(MLAAttentionImpl): for i in range(iters): toks = prefill_metadata.chunked_context.seq_tot[i] # chunk_seq_lens will be padded when pcp&dcp - context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[ - i] + context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[i] seq_len = torch.stack([current_seq_len, context_seq_len]) - context_seq_len_npu = self.get_context_seq_len_npu( - i, attn_metadata) - kv_c_normed = torch.empty(toks, - num_heads, - latent_kv_dim, - dtype=q_nope.dtype, - device=q_nope.device) - k_pe = torch.empty(toks, - num_heads, - rope_dim, - dtype=q_nope.dtype, - device=q_nope.device) + context_seq_len_npu = self.get_context_seq_len_npu(i, attn_metadata) + kv_c_normed = torch.empty(toks, num_heads, latent_kv_dim, dtype=q_nope.dtype, device=q_nope.device) + k_pe = torch.empty(toks, num_heads, rope_dim, dtype=q_nope.dtype, device=q_nope.device) torch_npu.atb.npu_paged_cache_load( cache_kv_c, @@ -1073,10 +1015,8 @@ class AscendMLAImpl(MLAAttentionImpl): toks=toks, ) kv_c_normed = kv_c_normed.squeeze() - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope \ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + kv_nope = self.kv_b_proj(kv_c_normed)[0].view(-1, self.num_heads, 
self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) mask = attn_metadata.attn_mask @@ -1098,7 +1038,8 @@ class AscendMLAImpl(MLAAttentionImpl): input_layout="type_bsnd", calc_type="calc_type_default", output=prefix_output, - softmax_lse=prefix_lse) + softmax_lse=prefix_lse, + ) return prefix_output, prefix_lse def _forward_prefill( @@ -1108,45 +1049,39 @@ class AscendMLAImpl(MLAAttentionImpl): k_nope: torch.Tensor, k_pe: torch.Tensor, value: torch.Tensor, - kv_c_and_k_pe_cache: Tuple[torch.Tensor], + kv_c_and_k_pe_cache: tuple[torch.Tensor], attn_metadata: AscendMLAMetadata, ) -> torch.Tensor: assert attn_metadata.prefill is not None assert len(kv_c_and_k_pe_cache) > 1 num_tokens = q_nope.size(0) - attn_output = torch.empty(num_tokens, - self.num_heads, - self.v_head_dim, - dtype=q_nope.dtype, - device=q_nope.device) - attn_lse = torch.empty(self.num_heads, - num_tokens, - dtype=torch.float32, - device=q_nope.device) - torch_npu.atb.npu_ring_mla(q_nope=q_nope, - q_rope=q_pe, - k_nope=k_nope, - k_rope=k_pe, - value=value, - mask=attn_metadata.attn_mask, - seqlen=attn_metadata.prefill.query_lens, - head_num=self.num_heads, - kv_head_num=self.num_heads, - pre_out=None, - prev_lse=None, - qk_scale=self.scale, - kernel_type="kernel_type_high_precision", - mask_type="mask_type_triu", - input_layout="type_bsnd", - calc_type="calc_type_first_ring", - output=attn_output, - softmax_lse=attn_lse) + attn_output = torch.empty(num_tokens, self.num_heads, self.v_head_dim, dtype=q_nope.dtype, device=q_nope.device) + attn_lse = torch.empty(self.num_heads, num_tokens, dtype=torch.float32, device=q_nope.device) + torch_npu.atb.npu_ring_mla( + q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope, + k_rope=k_pe, + value=value, + mask=attn_metadata.attn_mask, + seqlen=attn_metadata.prefill.query_lens, + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=None, + prev_lse=None, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="mask_type_triu", + input_layout="type_bsnd", + calc_type="calc_type_first_ring", + output=attn_output, + softmax_lse=attn_lse, + ) attn_output, attn_lse = self._compute_prefill_context( - q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, - attn_metadata, attn_output, attn_lse) + q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse + ) - attn_output = attn_output.reshape( - [num_tokens, self.num_heads * self.v_head_dim]) + attn_output = attn_output.reshape([num_tokens, self.num_heads * self.v_head_dim]) return attn_output def exec_kv_decode( @@ -1154,25 +1089,24 @@ class AscendMLAImpl(MLAAttentionImpl): kv_no_split: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, - kv_cache: Tuple, + kv_cache: tuple, slots: torch.Tensor, ): B = kv_no_split.shape[0] N = self.num_kv_heads S = 1 # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] - kv_no_split = kv_no_split.view( - B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) + kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( kv_no_split, - self.kv_a_layernorm.weight, + self.kv_a_layernorm.weight, # type: ignore[union-attr] cos, sin, slots.to(torch.int64), kv_cache[1], kv_cache[0], - epsilon=self.kv_a_layernorm.variance_epsilon, + epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr] 
cache_mode=cache_mode, ) return k_pe, k_nope @@ -1182,25 +1116,24 @@ class AscendMLAImpl(MLAAttentionImpl): kv_no_split: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, - kv_cache: Tuple, + kv_cache: tuple, slots: torch.Tensor, ): B = kv_no_split.shape[0] N = self.num_kv_heads S = 1 # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] - kv_no_split = kv_no_split.view( - B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) + kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) cache_mode = "PA" _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( kv_no_split, - self.kv_a_layernorm.weight, + self.kv_a_layernorm.weight, # type: ignore[union-attr] cos, sin, slots.to(torch.int64), kv_cache[1], kv_cache[0], - epsilon=self.kv_a_layernorm.variance_epsilon, + epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr] cache_mode=cache_mode, is_output_kv=True, ) @@ -1235,24 +1168,26 @@ class AscendMLAImpl(MLAAttentionImpl): actual_seq_lengths = None if self.enable_kv_nz: nz_fmt_last_dim = 16 - k_nope = k_nope.view(-1, self.num_kv_heads, - self.kv_lora_rank // nz_fmt_last_dim, - block_size, nz_fmt_last_dim) - k_pe = k_pe.view(-1, self.num_kv_heads, - self.qk_rope_head_dim // nz_fmt_last_dim, - block_size, nz_fmt_last_dim) + k_nope = k_nope.view( + -1, self.num_kv_heads, self.kv_lora_rank // nz_fmt_last_dim, block_size, nz_fmt_last_dim + ) + k_pe = k_pe.view( + -1, self.num_kv_heads, self.qk_rope_head_dim // nz_fmt_last_dim, block_size, nz_fmt_last_dim + ) else: - k_nope = k_nope.view(-1, self.num_kv_heads, block_size, - self.kv_lora_rank) - k_pe = k_pe.view(-1, self.num_kv_heads, block_size, - self.qk_rope_head_dim) + k_nope = k_nope.view(-1, self.num_kv_heads, block_size, self.kv_lora_rank) + k_pe = k_pe.view(-1, self.num_kv_heads, block_size, self.qk_rope_head_dim) attn_output_shape: tuple | None = None - if attn_metadata.attn_state in [ + if ( + attn_metadata.attn_state + in [ AscendAttentionState.SpecDecoding, AscendAttentionState.ChunkedPrefill, AscendAttentionState.DecodeOnly, - ] and self.speculative_config is not None: + ] + and self.speculative_config is not None + ): # The right part layout indicates the layout of the attention # output. It is set to NTD to avoid the need for a transpose # operation after attention. 
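# Illustrative aside, not part of the patch: process_weights_after_loading
# (earlier in this file) splits the kv_b_proj weight into W_UK / W_UV for
# weight absorption. Small sketch with made-up dimensions; the W_UK_T
# permutation is an assumption, chosen so W_UK_T.shape[0] == num_heads as the
# decode path expects.
import torch

kv_lora_rank, num_heads, qk_nope_head_dim, v_head_dim = 8, 2, 4, 4
w = torch.randn(kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim)
W_UK, W_UV = w.split([qk_nope_head_dim, v_head_dim], dim=-1)
W_UV = W_UV.transpose(0, 1).contiguous()     # (N, L, V), as in the code above
W_UK_T = W_UK.permute(1, 2, 0).contiguous()  # assumed (N, P, L) layout
print(W_UV.shape, W_UK_T.shape)  # torch.Size([2, 8, 4]) torch.Size([2, 4, 8])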
@@ -1272,34 +1207,31 @@ class AscendMLAImpl(MLAAttentionImpl): if self.enable_kv_nz: # Input shape: [num_tokens, seq_len, num_heads, dim] input_layout = "BSND_NBSD" - q_nope = q_nope.view(num_tokens, 1, self.num_heads, - -1).contiguous() + q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1).contiguous() q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1) else: # Input shape: [num_tokens, num_heads, seq_len, dim] input_layout = "BNSD_NBSD" - q_nope = q_nope.view(num_tokens, self.num_heads, 1, - -1).contiguous() + q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1).contiguous() q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) # Output shape: [num_heads, num_tokens, seq_len, dim] - attn_output_shape = (self.num_heads, num_tokens, 1, - self.kv_lora_rank) + attn_output_shape = (self.num_heads, num_tokens, 1, self.kv_lora_rank) sparse_mode = 0 attn_mask = None common_kwargs = { - 'query_rope': q_pe, - 'key_rope': k_pe, - 'num_heads': self.num_heads, - 'num_key_value_heads': self.num_kv_heads, - 'input_layout': input_layout, - 'atten_mask': attn_mask, - 'sparse_mode': sparse_mode, - 'scale': self.scale, - 'antiquant_mode': 0, - 'antiquant_scale': None, - 'block_table': decode_meta.block_table, - 'block_size': block_size, + "query_rope": q_pe, + "key_rope": k_pe, + "num_heads": self.num_heads, + "num_key_value_heads": self.num_kv_heads, + "input_layout": input_layout, + "atten_mask": attn_mask, + "sparse_mode": sparse_mode, + "scale": self.scale, + "antiquant_mode": 0, + "antiquant_scale": None, + "block_table": decode_meta.block_table, + "block_size": block_size, "actual_seq_lengths": actual_seq_lengths, "actual_seq_lengths_kv": decode_meta.seq_lens_list, } @@ -1319,49 +1251,52 @@ class AscendMLAImpl(MLAAttentionImpl): workspace = graph_params.workspaces.get(num_tokens) if workspace is None: workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( - q_nope, k_nope, k_nope, **common_kwargs) + q_nope, k_nope, k_nope, **common_kwargs + ) if forward_context.is_draft_model: update_draft_graph_params_workspaces(num_tokens, workspace) else: update_graph_params_workspaces(num_tokens, workspace) - attn_output = torch.empty(attn_output_shape, - dtype=q_nope.dtype, - device=q_nope.device) - softmax_lse = torch.empty(num_tokens, - dtype=q_nope.dtype, - device=q_nope.device) + attn_output = torch.empty(attn_output_shape, dtype=q_nope.dtype, device=q_nope.device) + softmax_lse = torch.empty(num_tokens, dtype=q_nope.dtype, device=q_nope.device) graph_params.attn_params[num_tokens].append( - (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope), - weak_ref_tensors(q_pe), weak_ref_tensors(k_pe), - self.num_heads, self.num_kv_heads, input_layout, - weak_ref_tensors(attn_mask) if attn_mask is not None else - None, sparse_mode, self.scale, decode_meta.block_table, - block_size, decode_meta.seq_lens_list, actual_seq_lengths, - weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse))) + ( + weak_ref_tensors(q_nope), + weak_ref_tensors(k_nope), + weak_ref_tensors(q_pe), + weak_ref_tensors(k_pe), + self.num_heads, + self.num_kv_heads, + input_layout, + weak_ref_tensors(attn_mask) if attn_mask is not None else None, + sparse_mode, + self.scale, + decode_meta.block_table, + block_size, + decode_meta.seq_lens_list, + actual_seq_lengths, + weak_ref_tensors(attn_output), + weak_ref_tensors(softmax_lse), + ) + ) torch.npu.graph_task_group_begin(stream) torch_npu.npu_fused_infer_attention_score.out( - q_nope, - k_nope, - k_nope, - **common_kwargs, - workspace=workspace, - out=[attn_output, 
softmax_lse]) + q_nope, k_nope, k_nope, **common_kwargs, workspace=workspace, out=[attn_output, softmax_lse] + ) handle = torch.npu.graph_task_group_end(stream) graph_params.handles[num_tokens].append(handle) else: - attn_output, _ = torch_npu.npu_fused_infer_attention_score( - q_nope, k_nope, k_nope, **common_kwargs) + attn_output, _ = torch_npu.npu_fused_infer_attention_score(q_nope, k_nope, k_nope, **common_kwargs) return self._v_up_proj(attn_output) def reorg_decode_q(self, decode_q_nope, decode_q_pe): return decode_q_nope, decode_q_pe - def _mla_preprocess_only_decode(self, hidden_states, kv_cache, - attn_metadata): + def _mla_preprocess_only_decode(self, hidden_states, kv_cache, attn_metadata): bsz = attn_metadata.num_decode_tokens hidden_states = hidden_states[:bsz] @@ -1371,14 +1306,12 @@ class AscendMLAImpl(MLAAttentionImpl): decode_k_nope, decode_k_pe = kv_cache[0], kv_cache[1] decode_q_nope = torch.empty( - (hidden_states.shape[0], self.W_UK_T.shape[0], - decode_k_nope.shape[-1]), + (hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_nope.shape[-1]), dtype=hidden_states.dtype, device=hidden_states.device, ) decode_q_pe = torch.empty( - (hidden_states.shape[0], self.W_UK_T.shape[0], - decode_k_pe.shape[-1]), + (hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_pe.shape[-1]), dtype=hidden_states.dtype, device=hidden_states.device, ) @@ -1413,68 +1346,55 @@ class AscendMLAImpl(MLAAttentionImpl): q_out1=decode_q_pe, kv_cache_out1=decode_k_pe, enable_inner_out=False, - inner_out=torch.tensor([], device=hidden_states.device)) - decode_q_nope = decode_q_nope.view(bsz, self.num_heads, - self.kv_lora_rank) + inner_out=torch.tensor([], device=hidden_states.device), + ) + decode_q_nope = decode_q_nope.view(bsz, self.num_heads, self.kv_lora_rank) decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) - decode_q_nope, decode_q_pe = self.reorg_decode_q( - decode_q_nope, decode_q_pe) + decode_q_nope, decode_q_pe = self.reorg_decode_q(decode_q_nope, decode_q_pe) - decode_preprocess_res = DecodeMLAPreprocessResult( - decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe) + decode_preprocess_res = DecodeMLAPreprocessResult(decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe) return decode_preprocess_res, None - def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache, - attn_metadata): + def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache, attn_metadata): num_decode_tokens = attn_metadata.num_decode_tokens num_actual_tokens = attn_metadata.num_actual_tokens prefill_kv_no_split = kv_no_split[num_decode_tokens:num_actual_tokens] prefill_q_c = q_c[num_decode_tokens:num_actual_tokens] - prefill_q = self.q_proj(prefill_q_c)[0] \ - .view(-1, self.num_heads, self.qk_head_dim) - prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] - prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim] + prefill_q = self.q_proj(prefill_q_c)[0].view(-1, self.num_heads, self.qk_head_dim) + prefill_q_pe = prefill_q[..., self.qk_nope_head_dim :] + prefill_q_nope = prefill_q[..., : self.qk_nope_head_dim] cos = attn_metadata.prefill.cos sin = attn_metadata.prefill.sin - prefill_slots = attn_metadata.slot_mapping[ - num_decode_tokens:num_actual_tokens] + prefill_slots = attn_metadata.slot_mapping[num_decode_tokens:num_actual_tokens] prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) if self.is_kv_producer: attn_metadata.reshape_cache_event = torch.npu.Event() - prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill( - prefill_kv_no_split, cos, sin, kv_cache, prefill_slots) + prefill_k_pe, 
prefill_k_c_normed = self.exec_kv_prefill(prefill_kv_no_split, cos, sin, kv_cache, prefill_slots) if self.is_kv_producer: attn_metadata.reshape_cache_event.record() - prefill_k_nope, prefill_value = self.kv_b_proj( - prefill_k_c_normed)[0].view( - -1, self.num_heads, - self.qk_nope_head_dim + self.v_head_dim).split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0], - self.num_kv_heads, -1) + prefill_k_nope, prefill_value = ( + self.kv_b_proj(prefill_k_c_normed)[0] + .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + ) + prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0], self.num_kv_heads, -1) prefill_k_pe = prefill_k_pe.expand((*prefill_k_nope.shape[:-1], -1)) - return PrefillMLAPreprocessResult(prefill_q_nope, prefill_q_pe, - prefill_k_nope, prefill_k_pe, - prefill_value) + return PrefillMLAPreprocessResult(prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value) def mla_preprocess_decode(self, q_c, kv_no_split, kv_cache, attn_metadata): num_decode_tokens = attn_metadata.num_decode_tokens decode_q_c = q_c[:num_decode_tokens] cos = attn_metadata.decode.cos sin = attn_metadata.decode.sin - decode_ql_nope, decode_q_pe = \ - self._q_proj_and_k_up_proj(decode_q_c) + decode_ql_nope, decode_q_pe = self._q_proj_and_k_up_proj(decode_q_c) decode_q_pe = self.rope_single(decode_q_pe, cos, sin) decode_slots = attn_metadata.slot_mapping[:num_decode_tokens:1] decode_kv_no_split = kv_no_split[:num_decode_tokens] - decode_k_pe, decode_k_nope = self.exec_kv_decode( - decode_kv_no_split, cos, sin, kv_cache, decode_slots) - return DecodeMLAPreprocessResult(decode_ql_nope, decode_q_pe, - decode_k_nope, decode_k_pe) + decode_k_pe, decode_k_nope = self.exec_kv_decode(decode_kv_no_split, cos, sin, kv_cache, decode_slots) + return DecodeMLAPreprocessResult(decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe) - def _mla_preprocess(self, layer_name, hidden_states, kv_cache, - attn_metadata, need_gather_q_kv): + def _mla_preprocess(self, layer_name, hidden_states, kv_cache, attn_metadata, need_gather_q_kv): # MLA Preprocess: # 1. 
Perform fused_qkv_a_proj and q_a_layernorm to obtain q_c and kv_no_split # or @@ -1487,28 +1407,26 @@ class AscendMLAImpl(MLAAttentionImpl): has_decode = attn_metadata.num_decodes > 0 has_prefill = attn_metadata.num_prefills > 0 if self.fused_qkv_a_proj is not None: - maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight, - dependency=hidden_states, - enabled=self.enable_prefetch) + maybe_npu_prefetch( + inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states, enabled=self.enable_prefetch + ) qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] q_c, kv_no_split = qkv_lora.split( [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1, ) - q_c = self.q_a_layernorm(q_c) + q_c = self.q_a_layernorm(q_c) # type: ignore[misc] # allgather need contiguous data kv_no_split = kv_no_split.contiguous() else: q_c = hidden_states - kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0] # type: ignore[misc] # Process for Flash Comm V1 - q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - q_c.contiguous(), need_gather_q_kv) - kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - kv_no_split.contiguous(), need_gather_q_kv) + q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(q_c.contiguous(), need_gather_q_kv) + kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(kv_no_split.contiguous(), need_gather_q_kv) - for layer in (self.layer_sharding_kwargs or []): + for layer in self.layer_sharding_kwargs or []: if is_hidden_layer(layer): reach_layer_for_shard_weight_series(layer) @@ -1518,12 +1436,10 @@ class AscendMLAImpl(MLAAttentionImpl): wait_for_kv_layer_from_connector(layer_name) # Preprocess for decode tokens if has_decode: - decode_preprocess_res = self.mla_preprocess_decode( - q_c, kv_no_split, kv_cache, attn_metadata) + decode_preprocess_res = self.mla_preprocess_decode(q_c, kv_no_split, kv_cache, attn_metadata) # Preprocess for prefill tokens if has_prefill: - prefill_preprocess_res = self.mla_preprocess_prefill( - q_c, kv_no_split, kv_cache, attn_metadata) + prefill_preprocess_res = self.mla_preprocess_prefill(q_c, kv_no_split, kv_cache, attn_metadata) return decode_preprocess_res, prefill_preprocess_res def get_num_actual_tokens(self, attn_metadata: M): @@ -1533,54 +1449,56 @@ class AscendMLAImpl(MLAAttentionImpl): self, layer_name, hidden_states: torch.Tensor, # query in unified attn - kv_cache: Tuple[torch.Tensor], + kv_cache: tuple[torch.Tensor], attn_metadata: M, need_gather_q_kv: bool = False, - output: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." if attn_metadata is None: # Profiling run. 
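For orientation while reading the slices in `_mla_preprocess` above: the code assumes decode tokens sit at the front of the batch, followed by prefill tokens and then any padding, as the indexing `q_c[:num_decode_tokens]` and `kv_no_split[num_decode_tokens:num_actual_tokens]` implies. A minimal, self-contained sketch of that layout (toy shapes and names, not the vllm-ascend API):

```python
# Toy illustration of the decode-first batch layout the slicing relies on;
# all names and shapes here are illustrative.
import torch

num_decode_tokens = 3   # decode tokens are packed at the front of the batch
num_actual_tokens = 8   # decode + prefill tokens; padding (if any) follows

hidden = torch.randn(10, 16)  # [padded_num_tokens, hidden_size]

decode_part = hidden[:num_decode_tokens]                    # rows 0..2
prefill_part = hidden[num_decode_tokens:num_actual_tokens]  # rows 3..7

assert decode_part.shape[0] + prefill_part.shape[0] == num_actual_tokens
```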
- for layer in (self.layer_sharding_kwargs or []): + for layer in self.layer_sharding_kwargs or []: if is_hidden_layer(layer): reach_layer_for_shard_weight_series(layer) return output.fill_(0) forward_context = get_forward_context() num_actual_tokens = self.get_num_actual_tokens(attn_metadata) - assert attn_metadata.num_decodes is not None and \ - attn_metadata.num_prefills is not None and \ - attn_metadata.num_decode_tokens is not None + assert ( + attn_metadata.num_decodes is not None + and attn_metadata.num_prefills is not None + and attn_metadata.num_decode_tokens is not None + ) has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens # Inputs and outputs may be padded for CUDA graphs output_padded = output - o_proj_input_shape = (forward_context.num_tokens, - self.num_heads * self.v_head_dim) - o_proj_input = torch.empty(o_proj_input_shape, - dtype=hidden_states.dtype, - device=hidden_states.device) + o_proj_input_shape = (forward_context.num_tokens, self.num_heads * self.v_head_dim) + o_proj_input = torch.empty(o_proj_input_shape, dtype=hidden_states.dtype, device=hidden_states.device) # MLA Preprocess - if self.enable_mlapo and \ - attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS: + if self.enable_mlapo and attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS: hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - hidden_states.contiguous(), need_gather_q_kv) + hidden_states.contiguous(), need_gather_q_kv + ) decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess_only_decode( - hidden_states, kv_cache, attn_metadata) + hidden_states, kv_cache, attn_metadata + ) else: decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess( - layer_name, hidden_states, kv_cache, attn_metadata, - need_gather_q_kv) + layer_name, hidden_states, kv_cache, attn_metadata, need_gather_q_kv + ) if decode_preprocess_res is not None: # MLA Preprocess for decoding - output_decode = self._forward_decode(decode_preprocess_res.ql_nope, - decode_preprocess_res.q_pe, - decode_preprocess_res.k_nope, - decode_preprocess_res.k_pe, - kv_cache[0].shape[1], - attn_metadata) + output_decode = self._forward_decode( + decode_preprocess_res.ql_nope, + decode_preprocess_res.q_pe, + decode_preprocess_res.k_nope, + decode_preprocess_res.k_pe, + kv_cache[0].shape[1], + attn_metadata, + ) o_proj_input[:num_decode_tokens] = output_decode @@ -1589,21 +1507,26 @@ class AscendMLAImpl(MLAAttentionImpl): # otherwise it may affect the accuracy # TODO: use an elegant way to overlap output_prefill = self._forward_prefill( - prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe, - prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe, - prefill_preprocess_res.value, kv_cache, attn_metadata) + prefill_preprocess_res.q_nope, + prefill_preprocess_res.q_pe, + prefill_preprocess_res.k_nope, + prefill_preprocess_res.k_pe, + prefill_preprocess_res.value, + kv_cache, + attn_metadata, + ) o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill # O proj MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 - maybe_npu_prefetch(inputs=self.o_proj.weight, - dependency=o_proj_input, - max_size=MAX_O_PROJ_PREFETCH_SIZE, - enabled=self.enable_prefetch) + maybe_npu_prefetch( + inputs=self.o_proj.weight, + dependency=o_proj_input, + max_size=MAX_O_PROJ_PREFETCH_SIZE, + enabled=self.enable_prefetch, + ) - output[...] = self.o_proj(o_proj_input, - is_prefill=prefill_preprocess_res - is not None)[0] + output[...] 
= self.o_proj(o_proj_input, is_prefill=prefill_preprocess_res is not None)[0] del o_proj_input diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 8575fa2b..5a01b1ca 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -1,19 +1,17 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Tuple, Type, TypeVar +from typing import TYPE_CHECKING, TypeVar import torch import torch_npu import vllm.envs as envs_vllm from torch import nn -from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group from vllm.forward_context import get_forward_context from vllm.logger import logger from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.triton_utils import HAS_TRITON -from vllm.v1.attention.backend import ( # type: ignore - AttentionBackend, AttentionCGSupport, MLAAttentionImpl) +from vllm.v1.attention.backend import AttentionBackend, AttentionCGSupport, MLAAttentionImpl # type: ignore from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder from vllm.v1.kv_cache_interface import AttentionSpec @@ -22,22 +20,33 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE, MLAPO_MAX_SUPPORTED_TOKENS -from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, - ascend_chunked_prefill_workspace_size, - maybe_save_kv_layer_to_connector, - trans_rope_weight, transdata, - wait_for_kv_layer_from_connector) +from vllm_ascend.attention.utils import ( + AscendCommonAttentionMetadata, + ascend_chunked_prefill_workspace_size, + maybe_save_kv_layer_to_connector, + trans_rope_weight, + transdata, + wait_for_kv_layer_from_connector, +) from vllm_ascend.distributed.utils import all_gather_async from vllm_ascend.ops.layer_shard_linear import ( - is_hidden_layer, post_process_after_loading_for_shard_weight_series, + is_hidden_layer, + post_process_after_loading_for_shard_weight_series, reach_layer_for_shard_weight_series, - register_all_layers_to_shard_weight_series) + register_all_layers_to_shard_weight_series, +) from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla from vllm_ascend.ops.triton.rope import rope_forward_triton from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.quantization.methods import AscendW8A8LinearMethod -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer, - enable_dsa_cp, enable_dsa_cp_with_layer_shard, maybe_trans_nz) +from vllm_ascend.utils import ( + ACL_FORMAT_FRACTAL_ND, + _round_up, + dispose_layer, + enable_dsa_cp, + enable_dsa_cp_with_layer_shard, + maybe_trans_nz, +) from vllm_ascend.worker.npu_input_batch import NPUInputBatch if TYPE_CHECKING: @@ -48,7 +57,6 @@ BMM_TRANS_MAX_SUPPORTED_TOKENS = 1024 class AscendSFABackend(AttentionBackend): - accept_output_buffer: bool = True @staticmethod @@ -63,12 +71,11 @@ class AscendSFABackend(AttentionBackend): return AscendSFAMetadataBuilder @staticmethod - def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, - head_size: int) -> tuple[int, ...]: + def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, head_size: int) -> tuple[int, ...]: 
return (num_blocks, block_size, num_kv_heads, head_size) @staticmethod - def get_impl_cls() -> Type["AscendSFAImpl"]: + def get_impl_cls() -> type["AscendSFAImpl"]: return AscendSFAImpl @@ -91,6 +98,7 @@ class AscendSFAMetadata: NOTE: Please read the comment at the top of the file before trying to understand this class """ + # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| @@ -109,11 +117,11 @@ class AscendSFAMetadata: # For logging. num_input_tokens: int = 0 # Number of tokens including padding. # The dimension of the attention heads - head_dim: Optional[int] = None + head_dim: int | None = None attn_mask: torch.Tensor = None # chunked prefill by default if no attn_states passed attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill - dsa_cp_context: Optional[DSACPContext] = None + dsa_cp_context: DSACPContext | None = None reshape_cache_event: torch.npu.Event = None @@ -136,37 +144,38 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): supports_dcp_with_varlen: bool = False, ): super().__init__( - kv_cache_spec, layer_names, vllm_config, device, + kv_cache_spec, + layer_names, + vllm_config, + device, metadata_cls if metadata_cls is not None else AscendSFAMetadata, - supports_dcp_with_varlen) + supports_dcp_with_varlen, + ) self.block_size = vllm_config.cache_config.block_size - self.max_blocks = (vllm_config.model_config.max_model_len + - self.block_size - 1) // self.block_size + self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size self.speculative_config = vllm_config.speculative_config self.decode_threshold = 1 if self.speculative_config: spec_token_num = self.speculative_config.num_speculative_tokens self.decode_threshold += spec_token_num - assert self.decode_threshold <= 16, f"decode_threshold exceeded \ + assert self.decode_threshold <= 16, ( + f"decode_threshold exceeded \ npu_fused_infer_attention_score TND layout's limit of 16, \ got {self.decode_threshold}" + ) self.attn_mask_builder = AttentionMaskBuilder(self.device) self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim self.enable_dsa_cp = enable_dsa_cp() max_num_reqs = vllm_config.scheduler_config.max_num_seqs - self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1, - dtype=torch.int32, - device=device) - self.actual_seq_lengths_key = torch.empty_like( - self.actual_seq_lengths_query) + self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device) + self.actual_seq_lengths_key = torch.empty_like(self.actual_seq_lengths_query) @staticmethod - def determine_chunked_prefill_workspace_size( - vllm_config: VllmConfig) -> int: + def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int: return ascend_chunked_prefill_workspace_size(vllm_config) @classmethod @@ -179,8 +188,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): # @override omitted only because of mypy limitation due to type variable. 
return AttentionCGSupport.UNIFORM_BATCH - def reorder_batch(self, input_batch: "NPUInputBatch", - scheduler_output: "SchedulerOutput") -> bool: + def reorder_batch(self, input_batch: "NPUInputBatch", scheduler_output: "SchedulerOutput") -> bool: # No need to reorder for Ascend SFA return False @@ -196,11 +204,9 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): block_table = common_attn_metadata.block_table_tensor[:num_reqs] slot_mapping = common_attn_metadata.slot_mapping[:num_input_tokens] - input_positions = common_attn_metadata.positions[: - num_input_tokens].long( - ) + input_positions = common_attn_metadata.positions[:num_input_tokens].long() - cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1] + cum_query_lens = common_attn_metadata.query_start_loc[1 : num_reqs + 1] seq_lens = common_attn_metadata.seq_lens[:num_reqs] cos, sin = get_cos_and_sin_mla(input_positions, True) @@ -216,8 +222,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): local_end = min(local_end_with_pad, num_actual_tokens) pad_size = num_tokens_pad - cos.shape[0] - assert cos.shape == sin.shape, \ - f"cos.shape must be equal to sin.shape, got {cos.shape} and {sin.shape}" + assert cos.shape == sin.shape, f"cos.shape must be equal to sin.shape, got {cos.shape} and {sin.shape}" if pad_size > 0: cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size)) @@ -225,9 +230,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): pad_size_slot = num_tokens_pad - slot_mapping.shape[0] if pad_size_slot > 0: - slot_mapping = nn.functional.pad(slot_mapping, - (0, pad_size_slot), - value=-1) + slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size_slot), value=-1) else: slot_mapping = slot_mapping[:num_tokens_pad] slot_mapping_cp = slot_mapping[local_start:local_end_with_pad] @@ -235,15 +238,18 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): cos = cos[local_start:local_end_with_pad] sin = sin[local_start:local_end_with_pad] - assert cos.shape[0] == num_tokens_per_device, \ + assert cos.shape[0] == num_tokens_per_device, ( f"cos.shape[0] must be equal to num_tokens_per_device, \ got {cos.shape[0]} and {num_tokens_per_device}" - assert slot_mapping_cp.shape[0] == num_tokens_per_device, \ + ) + assert slot_mapping_cp.shape[0] == num_tokens_per_device, ( f"slot_mapping_cp.shape[0] must be equal to num_tokens_per_device, \ got {slot_mapping_cp.shape[0]} and {num_tokens_per_device}" - assert slot_mapping.shape[0] == num_tokens_pad, \ + ) + assert slot_mapping.shape[0] == num_tokens_pad, ( f"slot_mapping.shape[0] must be equal to num_tokens_pad, \ got {slot_mapping.shape[0]} and {num_tokens_pad}" + ) actual_seq_lengths_query = self.actual_seq_lengths_query actual_seq_lengths_key = self.actual_seq_lengths_key @@ -291,31 +297,26 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): seq_lens=seq_lens, slot_mapping=slot_mapping, head_dim=self.model_config.get_head_size(), - attn_mask=self.attn_mask_builder.get_attention_mask( - self.model_config), + attn_mask=self.attn_mask_builder.get_attention_mask(self.model_config), attn_state=common_attn_metadata.attn_state, block_tables=block_table, sin=sin[:num_input_tokens], cos=cos[:num_input_tokens], - dsa_cp_context=dsa_cp_context) + dsa_cp_context=dsa_cp_context, + ) def build_for_graph_capture( self, common_attn_metadata: AscendCommonAttentionMetadata, attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, ): - if 
attn_state in { - AscendAttentionState.DecodeOnly, - AscendAttentionState.SpecDecoding - }: + if attn_state in {AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding}: attn_metadata = self.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, ) else: - raise NotImplementedError( - "Currently we only support building dummy metadata for DecodeOnly state" - ) + raise NotImplementedError("Currently we only support building dummy metadata for DecodeOnly state") attn_metadata.attn_state = attn_state return attn_metadata @@ -326,8 +327,9 @@ class AscendSFAImpl(MLAAttentionImpl): NOTE: Please read the comment at the top of the file before trying to understand this class """ + # Supports forward using the all-gather o_proj weight for decode requests when Sharded CP is enabled. - o_proj_full_pool: Optional[torch.Tensor] = None + o_proj_full_pool: torch.Tensor | None = None def __init__( self, @@ -335,12 +337,12 @@ class AscendSFAImpl(MLAAttentionImpl): head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float], + logits_soft_cap: float | None, attn_type: str, - kv_sharing_target_layer_name: Optional[str], + kv_sharing_target_layer_name: str | None, **kwargs, ) -> None: self.num_heads = num_heads @@ -350,26 +352,25 @@ class AscendSFAImpl(MLAAttentionImpl): self.kv_cache_dtype = kv_cache_dtype # MLA Args - self.q_lora_rank = kwargs['q_lora_rank'] - self.kv_lora_rank = kwargs['kv_lora_rank'] - self.qk_nope_head_dim = kwargs['qk_nope_head_dim'] - self.qk_rope_head_dim = kwargs['qk_rope_head_dim'] - self.qk_head_dim = kwargs['qk_head_dim'] - self.v_head_dim = kwargs['v_head_dim'] - self.rotary_emb = kwargs['rotary_emb'] - self.q_proj = kwargs['q_proj'] if self.q_lora_rank is None else kwargs[ - 'q_b_proj'] - self.fused_qkv_a_proj = kwargs.get('fused_qkv_a_proj', None) - self.kv_b_proj = kwargs['kv_b_proj'] - self.o_proj = kwargs['o_proj'] - self.indexer = kwargs['indexer'] - self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) - self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) - self.q_a_layernorm = kwargs.get('q_a_layernorm', None) + self.q_lora_rank = kwargs["q_lora_rank"] + self.kv_lora_rank = kwargs["kv_lora_rank"] + self.qk_nope_head_dim = kwargs["qk_nope_head_dim"] + self.qk_rope_head_dim = kwargs["qk_rope_head_dim"] + self.qk_head_dim = kwargs["qk_head_dim"] + self.v_head_dim = kwargs["v_head_dim"] + self.rotary_emb = kwargs["rotary_emb"] + self.q_proj = kwargs["q_proj"] if self.q_lora_rank is None else kwargs["q_b_proj"] + self.fused_qkv_a_proj = kwargs.get("fused_qkv_a_proj") + self.kv_b_proj = kwargs["kv_b_proj"] + self.o_proj = kwargs["o_proj"] + self.indexer = kwargs["indexer"] + self.kv_a_proj_with_mqa = kwargs.get("kv_a_proj_with_mqa") + self.kv_a_layernorm = kwargs.get("kv_a_layernorm") + self.q_a_layernorm = kwargs.get("q_a_layernorm") self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tp_group().rank_in_group - self.q_b_proj = kwargs['q_b_proj'] + self.q_b_proj = kwargs["q_b_proj"] ascend_config = get_ascend_config() self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp @@ -383,7 +384,9 @@ class AscendSFAImpl(MLAAttentionImpl): self.local_num_heads = self.num_heads self.vllm_config = get_current_vllm_config() - self.is_kv_producer = 
self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer + self.is_kv_producer = ( + self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer + ) # indexer param self.n_head: int = self.indexer.n_head # 64 @@ -400,38 +403,38 @@ class AscendSFAImpl(MLAAttentionImpl): self.local_num_heads = self.num_heads * self.tp_size if self.enable_dsa_cp_prefill_only: self.layer_sharding_kwargs = [] - for layer_name in (get_ascend_config().layer_sharding or []): + for layer_name in get_ascend_config().layer_sharding or []: if layer_name in kwargs: self.layer_sharding_kwargs.append(kwargs[layer_name]) else: logger.warning_once( - f"[SFAImpl init] Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration" + f"[SFAImpl init] Layer '{layer_name}' not found in kwargs for layer sharding, " + "skipping sharding configuration" ) - register_all_layers_to_shard_weight_series( - self.layer_sharding_kwargs) + register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs) def process_weights_after_loading(self, act_dtype: torch.dtype): # NOTE: We currently do not support quant kv_b_proj. assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod) # NOTE: Weight will be reshaped next, we need to revert and transpose it. - kv_b_proj_weight = torch_npu.npu_format_cast( - self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T + kv_b_proj_weight = torch_npu.npu_format_cast(self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T assert kv_b_proj_weight.shape == ( - self.kv_lora_rank, self.local_num_heads * - (self.qk_nope_head_dim + self.v_head_dim)), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.local_num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}") + self.kv_lora_rank, + self.local_num_heads * (self.qk_nope_head_dim + self.v_head_dim), + ), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.local_num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}" + ) kv_b_proj_weight = kv_b_proj_weight.view( self.kv_lora_rank, self.local_num_heads, self.qk_nope_head_dim + self.v_head_dim, ) - W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + W_UK, W_UV = kv_b_proj_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) # Convert from (L, N, V) to (N, L, V) self.W_UV = W_UV.transpose(0, 1).contiguous() @@ -445,10 +448,9 @@ class AscendSFAImpl(MLAAttentionImpl): dispose_layer(self.kv_b_proj) if self.enable_dsa_cp: if self.enable_dsa_cp_prefill_only: - for layer in (self.layer_sharding_kwargs or []): + for layer in self.layer_sharding_kwargs or []: if is_hidden_layer(layer): - post_process_after_loading_for_shard_weight_series( - layer) + post_process_after_loading_for_shard_weight_series(layer) else: self._init_o_proj_tp_full_params() @@ -459,15 +461,14 @@ class AscendSFAImpl(MLAAttentionImpl): None, ) reasons = [] - if self.fused_qkv_a_proj is None or not isinstance( - quant_method, AscendW8A8LinearMethod): + if self.fused_qkv_a_proj is None or not isinstance(quant_method, AscendW8A8LinearMethod): reasons.append( "Currently mlapo only supports W8A8 quantization in SFA scenario." "Some layers in your model are not quantized with W8A8," - "thus mlapo is disabled for these layers.") + "thus mlapo is disabled for these layers." 
+ ) if self.enable_dsa_cp: - reasons.append("Currently mlapo does not support SFA with CP," - "thus mlapo is disabled for these layers.") + reasons.append("Currently mlapo does not support SFA with CP,thus mlapo is disabled for these layers.") if reasons: self.enable_mlapo = False for msg in reasons: @@ -480,32 +481,31 @@ class AscendSFAImpl(MLAAttentionImpl): def _v_up_proj(self, x): num_input_tokens, _, _ = x.shape - if x.dtype in [torch.float16, torch.bfloat16] \ - and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \ - and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS: + if ( + x.dtype in [torch.float16, torch.bfloat16] + and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") + and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS + ): x = x.view(-1, self.local_num_heads, self.kv_lora_rank) - res = torch.empty((num_input_tokens, self.local_num_heads, self.v_head_dim), - dtype=x.dtype, - device=x.device) + res = torch.empty((num_input_tokens, self.local_num_heads, self.v_head_dim), dtype=x.dtype, device=x.device) torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res) x = res.reshape(-1, self.local_num_heads * self.v_head_dim) else: # Convert from (B, N, L) to (N, B, L) - x = x.view(-1, self.local_num_heads, - self.kv_lora_rank).transpose(0, 1) + x = x.view(-1, self.local_num_heads, self.kv_lora_rank).transpose(0, 1) # # Multiply (N, B, L) x (N, L, V) -> (N, B, V) x = torch.bmm(x, self.W_UV) # # Convert from (N, B, V) to (B, N * V) - x = x.transpose(0, - 1).reshape(-1, - self.local_num_heads * self.v_head_dim) + x = x.transpose(0, 1).reshape(-1, self.local_num_heads * self.v_head_dim) return x # Return `ql_nope`, `q_pe` def _q_proj_and_k_up_proj(self, x): - q_nope, q_pe = self.q_proj(x)[0]\ - .view(-1, self.local_num_heads, self.qk_head_dim)\ + q_nope, q_pe = ( + self.q_proj(x)[0] + .view(-1, self.local_num_heads, self.qk_head_dim) .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + ) # Convert from (B, N, P) to (N, B, P) q_nope = q_nope.transpose(0, 1) @@ -519,27 +519,26 @@ class AscendSFAImpl(MLAAttentionImpl): kv_no_split: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, - kv_cache: Tuple, + kv_cache: tuple, slots: torch.Tensor, ): B = kv_no_split.shape[0] N = self.num_kv_heads S = 1 # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] - kv_no_split = kv_no_split.view( - B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) + kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) cache_mode = "PA" if self.enable_dsa_cp: _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( kv_no_split, - self.kv_a_layernorm.weight, + self.kv_a_layernorm.weight, # type: ignore[union-attr] cos, sin, slots.to(torch.int64), kv_cache[1], kv_cache[0], - epsilon=self.kv_a_layernorm.variance_epsilon, + epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr] cache_mode=cache_mode, is_output_kv=True, ) @@ -547,13 +546,13 @@ class AscendSFAImpl(MLAAttentionImpl): else: torch_npu.npu_kv_rmsnorm_rope_cache( kv_no_split, - self.kv_a_layernorm.weight, + self.kv_a_layernorm.weight, # type: ignore[union-attr] cos, sin, slots.to(torch.int64), kv_cache[1], kv_cache[0], - epsilon=self.kv_a_layernorm.variance_epsilon, + epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr] cache_mode=cache_mode, ) return None, None @@ -577,78 +576,53 @@ class AscendSFAImpl(MLAAttentionImpl): assert self.kv_a_proj_with_mqa is None assert self.fused_qkv_a_proj is not None - kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[ - ..., 
self.q_lora_rank:].contiguous() - q_a_proj_wt = self.fused_qkv_a_proj.weight.data[ - ..., :self.q_lora_rank].contiguous() + kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous() + q_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous() kv_a_proj_wt = kv_a_proj_wt.t().contiguous() kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim) kv_a_proj_wt = kv_a_proj_wt.t().contiguous() wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1) wd_qkv = wd_qkv.t().contiguous() - wd_qkv = transdata(wd_qkv, - block_size=(16, 32)).unsqueeze(0).contiguous() + wd_qkv = transdata(wd_qkv, block_size=(16, 32)).unsqueeze(0).contiguous() self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29) - kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[ - self.q_lora_rank:].contiguous() - q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[:self. - q_lora_rank].contiguous( - ) - kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape( - self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() - kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl, - self.qk_rope_head_dim) - kv_a_proj_deq_scl = kv_a_proj_deq_scl.view( - self.kv_lora_rank + self.qk_rope_head_dim).contiguous() - self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl), - dim=-1).contiguous() + kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[self.q_lora_rank :].contiguous() + q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[: self.q_lora_rank].contiguous() + kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() + kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl, self.qk_rope_head_dim) + kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(self.kv_lora_rank + self.qk_rope_head_dim).contiguous() + self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl), dim=-1).contiguous() - kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[ - self.q_lora_rank:].contiguous() - q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[:self. 
- q_lora_rank].contiguous( - ) + kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[self.q_lora_rank :].contiguous() + q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[: self.q_lora_rank].contiguous() - kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape( - self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() - kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias, - self.qk_rope_head_dim) - kv_a_proj_qt_bias = kv_a_proj_qt_bias.view( - self.kv_lora_rank + self.qk_rope_head_dim).contiguous() - self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias), - dim=-1).contiguous() + kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() + kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias, self.qk_rope_head_dim) + kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(self.kv_lora_rank + self.qk_rope_head_dim).contiguous() + self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias), dim=-1).contiguous() wu_q = self.q_proj.weight.data - wu_q = wu_q.t().reshape(self.num_heads, - self.qk_nope_head_dim + self.qk_rope_head_dim, - -1) + wu_q = wu_q.t().reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim) - wu_q = wu_q.reshape( - self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), - -1) + wu_q = wu_q.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), -1) wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous() self.wu_q = torch_npu.npu_format_cast(wu_q, 29) qb_deq_scl = self.q_proj.deq_scale.data - qb_deq_scl = qb_deq_scl.reshape( - self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) + qb_deq_scl = qb_deq_scl.reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim) - self.qb_deq_scl = qb_deq_scl.reshape( - self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) + self.qb_deq_scl = qb_deq_scl.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) qb_qt_bias = self.q_proj.quant_bias.data - qb_qt_bias = qb_qt_bias.reshape( - self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) + qb_qt_bias = qb_qt_bias.reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim) - self.qb_qt_bias = qb_qt_bias.reshape( - self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) + self.qb_qt_bias = qb_qt_bias.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) device = self.q_proj.weight.device - self.gamma1 = self.q_a_layernorm.weight.data - self.beta1 = self.q_a_layernorm.bias.data - self.gamma2 = self.kv_a_layernorm.weight.data + self.gamma1 = self.q_a_layernorm.weight.data # type: ignore[union-attr] + self.beta1 = self.q_a_layernorm.bias.data # type: ignore[union-attr] + self.gamma2 = self.kv_a_layernorm.weight.data # type: ignore[union-attr] self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data self.quant_scale1 = self.q_proj.input_scale.data @@ -659,9 +633,11 @@ class AscendSFAImpl(MLAAttentionImpl): # On KV consumers (decode-only) MLAPO uses the transformed weights built above; # the original fused_qkv_a_proj/q_proj weights and quant params are no longer # referenced, so drop them to save memory. 
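A note on the memory-saving comment above: the gain comes purely from dropping the last Python references to the original weight tensors so the allocator can reclaim their storage. A rough sketch of the idea, using a stand-in object rather than the real quantized linear layer:

```python
# Rough sketch: once no tensor references remain, the backing storage becomes
# eligible for reuse. The attribute names mirror the ones cleared below; the
# layer object itself is only a stand-in.
from types import SimpleNamespace

import torch

layer = SimpleNamespace(
    weight=torch.empty(1024, 1024),
    deq_scale=torch.empty(1024),
    quant_bias=torch.empty(1024),
)

# Equivalent in spirit to the assignments in process_weights_after_loading:
layer.weight = None
layer.deq_scale = None
layer.quant_bias = None
```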
- if self.vllm_config.kv_transfer_config is not None and \ - self.vllm_config.kv_transfer_config.is_kv_consumer and \ - self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS: + if ( + self.vllm_config.kv_transfer_config is not None + and self.vllm_config.kv_transfer_config.is_kv_consumer + and self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS + ): self.fused_qkv_a_proj.weight = None self.fused_qkv_a_proj.deq_scale = None self.fused_qkv_a_proj.quant_bias = None @@ -673,13 +649,12 @@ class AscendSFAImpl(MLAAttentionImpl): def _sfa_preprocess_decode( self, hidden_states: torch.Tensor, - kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], attn_metadata: M, need_gather_q_kv: bool, num_input_tokens: int, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - hidden_states.contiguous(), need_gather_q_kv) + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(hidden_states.contiguous(), need_gather_q_kv) k_nope, k_pe = kv_cache[0], kv_cache[1] ql_nope = torch.empty( (num_input_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]), @@ -734,17 +709,17 @@ class AscendSFAImpl(MLAAttentionImpl): self, layer_name, hidden_states: torch.Tensor, # query in unified attn - kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], attn_metadata: M, need_gather_q_kv: bool = False, - output: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." forward_context = get_forward_context() if attn_metadata is None: # Profiling run. if self.enable_dsa_cp_prefill_only and not forward_context.in_profile_run: - for layer in (self.layer_sharding_kwargs or []): + for layer in self.layer_sharding_kwargs or []: if is_hidden_layer(layer): reach_layer_for_shard_weight_series(layer) return output.fill_(0) @@ -761,12 +736,13 @@ class AscendSFAImpl(MLAAttentionImpl): # all-gather o_proj weight for prefill stage of PD mix node o_proj_full_handle = None - # if is PD mix stage, using original TP o_proj weight, and also need to full gather for o_proj weight for prefill stage. + # if is PD mix stage, using original TP o_proj weight, and also need to full gather for o_proj + # weight for prefill stage. should_shard_weight = self.enable_dsa_cp_prefill_only or attn_metadata.attn_state not in { - AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding, } - if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS: hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode( hidden_states=hidden_states, @@ -776,35 +752,30 @@ class AscendSFAImpl(MLAAttentionImpl): num_input_tokens=num_input_tokens, ) q, k = self.indexer_select_pre_process( - x=hidden_states, - qr=q_c, - cos=cos, - sin=sin, - need_gather_q_kv=need_gather_q_kv) + x=hidden_states, qr=q_c, cos=cos, sin=sin, need_gather_q_kv=need_gather_q_kv + ) else: assert self.fused_qkv_a_proj is not None, "q lora is required for DSA." 
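Stepping back to the `should_shard_weight` flag computed a few lines up in this hunk: it encodes a small decision rule — prefill-only CP always takes the sharded-weight path, and in the PD-mix case any step that is not pure decode (or spec decode) needs the fully gathered o_proj weight. An illustrative restatement in isolation (enum names shortened; not the real `AscendAttentionState`):

```python
# Illustrative restatement of the should_shard_weight decision.
from enum import Enum, auto


class AttnState(Enum):
    DECODE_ONLY = auto()
    SPEC_DECODING = auto()
    CHUNKED_PREFILL = auto()


def should_shard_weight(prefill_only_cp: bool, attn_state: AttnState) -> bool:
    return prefill_only_cp or attn_state not in {
        AttnState.DECODE_ONLY,
        AttnState.SPEC_DECODING,
    }


assert should_shard_weight(False, AttnState.DECODE_ONLY) is False
assert should_shard_weight(False, AttnState.CHUNKED_PREFILL) is True
assert should_shard_weight(True, AttnState.DECODE_ONLY) is True
```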
- maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight, - dependency=hidden_states, - enabled=self.enable_prefetch) + maybe_npu_prefetch( + inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states, enabled=self.enable_prefetch + ) qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] q_c, kv_no_split = qkv_lora.split( [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1, ) + assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized" q_c = self.q_a_layernorm(q_c) # Process for Flash Comm V1 if need_gather_q_kv: - q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - q_c.contiguous(), need_gather_q_kv) + q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(q_c.contiguous(), need_gather_q_kv) kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - kv_no_split.contiguous(), need_gather_q_kv) + kv_no_split.contiguous(), need_gather_q_kv + ) q, k = self.indexer_select_pre_process( - x=hidden_states, - qr=q_c, - cos=cos, - sin=sin, - need_gather_q_kv=need_gather_q_kv) + x=hidden_states, qr=q_c, cos=cos, sin=sin, need_gather_q_kv=need_gather_q_kv + ) wait_for_kv_layer_from_connector(layer_name) @@ -815,22 +786,20 @@ class AscendSFAImpl(MLAAttentionImpl): actual_seq_lengths_query = attn_metadata.dsa_cp_context.actual_seq_lengths_query actual_seq_lengths_key = attn_metadata.dsa_cp_context.actual_seq_lengths_key - k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, - slot_mapping) + k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping) if self.enable_dsa_cp: assert k_pe is not None assert k_nope is not None # support all_gather kv async for communication calculation overlap fused_kv_no_split, kv_ag_handle = all_gather_async( - torch.cat([ - k_pe.view(-1, k_pe.shape[-1]), - k_nope.view(-1, k_nope.shape[-1]), - k.view(-1, k.shape[-1]) - ], - dim=1), + torch.cat( + [k_pe.view(-1, k_pe.shape[-1]), k_nope.view(-1, k_nope.shape[-1]), k.view(-1, k.shape[-1])], + dim=1, + ), get_tp_group(), - async_op=should_shard_weight) + async_op=should_shard_weight, + ) ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c) q_pe = self.rope_single(q_pe, cos, sin) @@ -840,34 +809,27 @@ class AscendSFAImpl(MLAAttentionImpl): kv_ag_handle.wait() if self.enable_dsa_cp_prefill_only: - for layer in (self.layer_sharding_kwargs or []): + for layer in self.layer_sharding_kwargs or []: if is_hidden_layer(layer): reach_layer_for_shard_weight_series(layer) elif should_shard_weight: _, o_proj_full_handle = all_gather_async( - self.o_proj_tp_weight, - get_tp_group(), - output=AscendSFAImpl.o_proj_full_pool) + self.o_proj_tp_weight, get_tp_group(), output=AscendSFAImpl.o_proj_full_pool + ) if kv_cache is not None: assert fused_kv_no_split is not None - k_pe, k_nope, k = fused_kv_no_split.split([ - self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim - ], - dim=-1) + k_pe, k_nope, k = fused_kv_no_split.split( + [self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1 + ) slot_mapping = attn_metadata.slot_mapping.view(-1, 1) - torch_npu.npu_scatter_nd_update_( - kv_cache[0].view(-1, k_nope.shape[-1]), slot_mapping, - k_nope) - torch_npu.npu_scatter_nd_update_( - kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping, - k_pe) + torch_npu.npu_scatter_nd_update_(kv_cache[0].view(-1, k_nope.shape[-1]), slot_mapping, k_nope) + torch_npu.npu_scatter_nd_update_(kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping, k_pe) if kv_cache is not None: torch_npu.npu_scatter_nd_update_( - kv_cache[2].view(-1, k.shape[-1]), - attn_metadata.slot_mapping.view(-1, 1), - 
k.view(-1, k.shape[-1])) # b, s, n, d + kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view(-1, 1), k.view(-1, k.shape[-1]) + ) # b, s, n, d topk_indices = self.indexer_select_post_process( x=hidden_states, @@ -880,7 +842,8 @@ class AscendSFAImpl(MLAAttentionImpl): sin=sin, actual_seq_lengths_query=actual_seq_lengths_query, actual_seq_lengths_key=actual_seq_lengths_key, - need_gather_q_kv=need_gather_q_kv) + need_gather_q_kv=need_gather_q_kv, + ) attn_output = torch.ops._C_ascend.npu_sparse_flash_attention( query=ql_nope, @@ -900,10 +863,12 @@ class AscendSFAImpl(MLAAttentionImpl): ) attn_output = self._v_up_proj(attn_output) - maybe_npu_prefetch(inputs=self.o_proj.weight, - dependency=attn_output, - max_size=MAX_O_PROJ_PREFETCH_SIZE, - enabled=self.enable_prefetch) + maybe_npu_prefetch( + inputs=self.o_proj.weight, + dependency=attn_output, + max_size=MAX_O_PROJ_PREFETCH_SIZE, + enabled=self.enable_prefetch, + ) if self.enable_dsa_cp and not self.enable_dsa_cp_prefill_only: # When using SFA-CP with pd mixed, o_proj has two cases: @@ -913,7 +878,8 @@ class AscendSFAImpl(MLAAttentionImpl): attn_output=attn_output, output=output, o_proj_full_handle=o_proj_full_handle, - should_shard_weight=should_shard_weight) + should_shard_weight=should_shard_weight, + ) if not require_o_proj_forward: return result attn_output = result @@ -933,8 +899,7 @@ class AscendSFAImpl(MLAAttentionImpl): need_gather_q_kv: bool = False, ): k_proj, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128] - k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - k_proj, need_gather_q_kv) + k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(k_proj, need_gather_q_kv) k = self.k_norm(k_proj).unsqueeze(1) k = k.view(-1, 1, self.head_dim) @@ -944,17 +909,9 @@ class AscendSFAImpl(MLAAttentionImpl): cos = cos.view(-1, self.qk_rope_head_dim) sin = sin.view(-1, self.qk_rope_head_dim) - q, k = rope_forward_triton(q, - k, - cos, - sin, - rope_dim=self.qk_rope_head_dim, - is_neox_style=True) + q, k = rope_forward_triton(q, k, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=True) else: - k_pe, k_nope = torch.split( - k, - [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], - dim=-1) + k_pe, k_nope = torch.split(k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1) cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) @@ -972,9 +929,9 @@ class AscendSFAImpl(MLAAttentionImpl): self, x: torch.Tensor, qr: torch.Tensor, - q: Optional[torch.Tensor], + q: torch.Tensor | None, k: torch.Tensor, - kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], attn_metadata: M, cos: torch.Tensor, sin: torch.Tensor, @@ -988,9 +945,8 @@ class AscendSFAImpl(MLAAttentionImpl): cos_q, sin_q = cos, sin q_pe, q_nope = torch.split( - q, - [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], - dim=-1) # [b,s,64,64+64] + q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1 + ) # [b,s,64,64+64] q_pe = q_pe.unsqueeze(2) q_pe = torch_npu.npu_rotary_mul(q_pe, cos_q, sin_q) @@ -1000,17 +956,14 @@ class AscendSFAImpl(MLAAttentionImpl): if kv_cache is not None: if self.is_kv_producer: attn_metadata.reshape_cache_event = torch.npu.Event() - torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]), - attn_metadata.slot_mapping.view( - -1, 1), - k.view(-1, - k.shape[-1])) # b, s, n, d + torch_npu.npu_scatter_nd_update_( + kv_cache[2].view(-1, k.shape[-1]), 
attn_metadata.slot_mapping.view(-1, 1), k.view(-1, k.shape[-1]) + ) # b, s, n, d if self.is_kv_producer: attn_metadata.reshape_cache_event.record() weights, _ = self.weights_proj(x) - weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - weights, need_gather_q_kv) + weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(weights, need_gather_q_kv) block_table = attn_metadata.block_tables @@ -1024,14 +977,15 @@ class AscendSFAImpl(MLAAttentionImpl): layout_query="TND", layout_key="PA_BSND", sparse_count=2048, - sparse_mode=3) + sparse_mode=3, + ) return topk_indices def _init_o_proj_tp_full_params(self): """ - Initialize TP-mode and Full-mode parameters for o_proj weight, + Initialize TP-mode and Full-mode parameters for o_proj weight, preparing for weight switching in PD mix stage. - + For PD mix stage: - Use original TP o_proj weight for decode phase - Need full-gather o_proj weight from all TP ranks for prefill phase @@ -1039,38 +993,33 @@ class AscendSFAImpl(MLAAttentionImpl): if AscendSFAImpl.o_proj_full_pool is None: sample = self.o_proj.weight AscendSFAImpl.o_proj_full_pool = torch.empty( - (sample.shape[0] * self.tp_size, sample.shape[1]), - dtype=sample.dtype, - device=sample.device) + (sample.shape[0] * self.tp_size, sample.shape[1]), dtype=sample.dtype, device=sample.device + ) # Save TP-mode parameters (original sharded weights) self.o_proj_tp_weight = self.o_proj.weight.clone().detach() - self.o_proj_tp_aclnn_input_scale = self.o_proj.aclnn_input_scale.clone( - ).detach() - self.o_proj_tp_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.clone( - ).detach() - self.o_proj_tp_aclnn_input_offset = self.o_proj.aclnn_input_offset.clone( - ).detach() + self.o_proj_tp_aclnn_input_scale = self.o_proj.aclnn_input_scale.clone().detach() + self.o_proj_tp_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.clone().detach() + self.o_proj_tp_aclnn_input_offset = self.o_proj.aclnn_input_offset.clone().detach() # Initially switch to TP mode for graph capture self.o_proj.weight.set_(self.o_proj_tp_weight) self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale) - self.o_proj.aclnn_input_scale_reciprocal.set_( - self.o_proj_tp_aclnn_input_scale_reciprocal) + self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_tp_aclnn_input_scale_reciprocal) self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset) # Precompute Full-mode quantization parameters by repeating TP parameters across all TP ranks - self.o_proj_full_aclnn_input_scale = self.o_proj.aclnn_input_scale.repeat( - self.tp_size) - self.o_proj_full_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.repeat( - self.tp_size) - self.o_proj_full_aclnn_input_offset = self.o_proj.aclnn_input_offset.repeat( - self.tp_size) + self.o_proj_full_aclnn_input_scale = self.o_proj.aclnn_input_scale.repeat(self.tp_size) + self.o_proj_full_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.repeat(self.tp_size) + self.o_proj_full_aclnn_input_offset = self.o_proj.aclnn_input_offset.repeat(self.tp_size) def _handle_o_proj_weight_switch_and_forward( - self, attn_output: torch.Tensor, output: torch.Tensor, - o_proj_full_handle: Optional[torch.distributed.Work], - should_shard_weight: bool) -> Tuple[torch.Tensor, bool]: + self, + attn_output: torch.Tensor, + output: torch.Tensor, + o_proj_full_handle: torch.distributed.Work | None, + should_shard_weight: bool, + ) -> tuple[torch.Tensor, bool]: """ Handle o_proj weight switching between TP-mode 
and Full-mode, and execute forward computation. """ @@ -1082,36 +1031,30 @@ class AscendSFAImpl(MLAAttentionImpl): # Switch o_proj to Full-mode (gathered weight from all TP ranks) self.o_proj.weight.set_(AscendSFAImpl.o_proj_full_pool) - self.o_proj.aclnn_input_scale.set_( - self.o_proj_full_aclnn_input_scale) - self.o_proj.aclnn_input_scale_reciprocal.set_( - self.o_proj_full_aclnn_input_scale_reciprocal) - self.o_proj.aclnn_input_offset.set_( - self.o_proj_full_aclnn_input_offset) + self.o_proj.aclnn_input_scale.set_(self.o_proj_full_aclnn_input_scale) + self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_full_aclnn_input_scale_reciprocal) + self.o_proj.aclnn_input_offset.set_(self.o_proj_full_aclnn_input_offset) # Apply quantization method and execute forward computation - output[...] = self.o_proj.quant_method.quant_method.apply( - self.o_proj, attn_output) + output[...] = self.o_proj.quant_method.quant_method.apply(self.o_proj, attn_output) # Switch o_proj back to TP-mode for subsequent decode operations self.o_proj.weight.set_(self.o_proj_tp_weight) - self.o_proj.aclnn_input_scale.set_( - self.o_proj_tp_aclnn_input_scale) - self.o_proj.aclnn_input_scale_reciprocal.set_( - self.o_proj_tp_aclnn_input_scale_reciprocal) - self.o_proj.aclnn_input_offset.set_( - self.o_proj_tp_aclnn_input_offset) + self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale) + self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_tp_aclnn_input_scale_reciprocal) + self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset) return output, False else: # For decode scenario: perform all-to-all communication on o_proj input activations # Reshape for all-to-all: [batch * seq, tp_size, head_dim] -> [tp_size, batch * seq, head_dim] - send = attn_output.view(-1, self.tp_size, self.num_heads * - self.v_head_dim).permute(1, 0, 2).reshape( - -1, self.num_heads * self.v_head_dim) + send = ( + attn_output.view(-1, self.tp_size, self.num_heads * self.v_head_dim) + .permute(1, 0, 2) + .reshape(-1, self.num_heads * self.v_head_dim) + ) attn_output = torch.empty_like(send) - torch.distributed.all_to_all_single( - attn_output, send, group=get_tp_group().device_group) + torch.distributed.all_to_all_single(attn_output, send, group=get_tp_group().device_group) return attn_output, True diff --git a/vllm_ascend/core/recompute_scheduler.py b/vllm_ascend/core/recompute_scheduler.py index 356583a3..45f3160d 100644 --- a/vllm_ascend/core/recompute_scheduler.py +++ b/vllm_ascend/core/recompute_scheduler.py @@ -21,26 +21,21 @@ from __future__ import annotations import time from collections import defaultdict from dataclasses import dataclass, fields -from typing import Type, Union from vllm._bc_linter import bc_linter_include from vllm.config import SchedulerConfig, VllmConfig from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata from vllm.distributed.kv_events import KVEventBatch -from vllm.distributed.kv_transfer.kv_connector.v1.base import \ - KVConnectorMetadata -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import \ - KVConnectorStats +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.logger import init_logger from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.async_scheduler import AsyncScheduler from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput -from vllm.v1.core.sched.request_queue import 
(SchedulingPolicy, - create_request_queue) +from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.utils import check_stop, remove_all -from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, - EngineCoreOutputs, FinishReason) +from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs, FinishReason from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.spec_decode.metrics import SpecDecodingStats @@ -51,26 +46,22 @@ logger = init_logger(__name__) @dataclass class RecomputeSchedulerConfig(SchedulerConfig): - scheduler_cls: Union[str, Type[object]] = ( - "vllm_ascend.core.recompute_scheduler.RecomputeScheduler") + scheduler_cls: str | type[object] = "vllm_ascend.core.recompute_scheduler.RecomputeScheduler" @classmethod def initialize_from_config(cls, vllm_config: VllmConfig): vllm_scheduler_config = vllm_config.scheduler_config scheduler_config = { field.name: getattr(vllm_scheduler_config, field.name) - for field in fields(vllm_scheduler_config) if field.init + for field in fields(vllm_scheduler_config) + if field.init } if vllm_scheduler_config.async_scheduling: - scheduler_config["scheduler_cls"] = ( - "vllm_ascend.core.recompute_scheduler.AsyncRecomputeScheduler") + scheduler_config["scheduler_cls"] = "vllm_ascend.core.recompute_scheduler.AsyncRecomputeScheduler" else: - scheduler_config["scheduler_cls"] = ( - "vllm_ascend.core.recompute_scheduler.RecomputeScheduler") - scheduler_config[ - "max_model_len"] = vllm_config.model_config.max_model_len - scheduler_config[ - "is_encoder_decoder"] = vllm_config.model_config.is_encoder_decoder + scheduler_config["scheduler_cls"] = "vllm_ascend.core.recompute_scheduler.RecomputeScheduler" + scheduler_config["max_model_len"] = vllm_config.model_config.max_model_len + scheduler_config["is_encoder_decoder"] = vllm_config.model_config.is_encoder_decoder return cls(**scheduler_config) @@ -125,33 +116,32 @@ class RecomputeScheduler(Scheduler): while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - if (request.num_output_placeholders > 0 - # This is (num_computed_tokens + 1) - (num_output_placeholders - 1). - # Since output placeholders are also included in the computed tokens - # count, we subtract (num_output_placeholders - 1) to remove any draft - # tokens, so that we can be sure no further steps are needed even if - # they are all rejected. - and request.num_computed_tokens + 2 - - request.num_output_placeholders - >= request.num_prompt_tokens + request.max_tokens): + if ( + request.num_output_placeholders > 0 + # This is (num_computed_tokens + 1) - (num_output_placeholders - 1). + # Since output placeholders are also included in the computed tokens + # count, we subtract (num_output_placeholders - 1) to remove any draft + # tokens, so that we can be sure no further steps are needed even if + # they are all rejected. + and request.num_computed_tokens + 2 - request.num_output_placeholders + >= request.num_prompt_tokens + request.max_tokens + ): # Async scheduling: Avoid scheduling an extra step when we are sure that # the previous step has reached request.max_tokens. We don't schedule # partial draft tokens since this prevents uniform decode optimizations. 
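The comment above compresses a subtle piece of arithmetic. A worked example with illustrative numbers, showing why `num_computed_tokens + 2 - num_output_placeholders` is the token count guaranteed even if every in-flight draft token is rejected:

```python
# Worked example (illustrative numbers) of the async-scheduling stop check.
num_prompt_tokens = 10
max_tokens = 5                 # request finishes once 5 output tokens exist
num_output_placeholders = 3    # 1 real step + 2 draft tokens in flight
num_computed_tokens = 16       # placeholders are included in this count

# Even if every draft token is rejected, the previous step still produced
# (num_computed_tokens + 1) - (num_output_placeholders - 1) usable tokens:
guaranteed = num_computed_tokens + 2 - num_output_placeholders
assert guaranteed == 15 == num_prompt_tokens + max_tokens
# guaranteed >= prompt + max_tokens, so scheduling another step is wasted work.
```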
req_index += 1 continue - num_new_tokens = (request.num_tokens_with_spec + - request.num_output_placeholders - - request.num_computed_tokens) + num_new_tokens = ( + request.num_tokens_with_spec + request.num_output_placeholders - request.num_computed_tokens + ) if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = min(num_new_tokens, token_budget) # Make sure the input position does not exceed the max model len. # This is necessary when using spec decoding. - num_new_tokens = min( - num_new_tokens, - self.max_model_len - 1 - request.num_computed_tokens) + num_new_tokens = min(num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens) # Schedule encoder inputs. encoder_inputs_to_schedule = None @@ -209,9 +199,10 @@ class RecomputeScheduler(Scheduler): recomputed_req = self.running.pop() self.kv_cache_manager.free(recomputed_req) recomputed_reqs.append( - RecomputeReqInfo(recomputed_req.request_id, - recomputed_req.output_token_ids, - recomputed_req.client_index)) + RecomputeReqInfo( + recomputed_req.request_id, recomputed_req.output_token_ids, recomputed_req.client_index + ) + ) if recomputed_req == request: break else: @@ -223,28 +214,23 @@ class RecomputeScheduler(Scheduler): self.running.remove(preempted_req) if preempted_req in scheduled_running_reqs: scheduled_running_reqs.remove(preempted_req) - token_budget += num_scheduled_tokens[ - preempted_req.request_id] + token_budget += num_scheduled_tokens[preempted_req.request_id] req_to_new_blocks.pop(preempted_req.request_id) - num_scheduled_tokens.pop( - preempted_req.request_id) - scheduled_spec_decode_tokens.pop( - preempted_req.request_id, None) - preempted_encoder_inputs = scheduled_encoder_inputs.pop( - preempted_req.request_id, None) + num_scheduled_tokens.pop(preempted_req.request_id) + scheduled_spec_decode_tokens.pop(preempted_req.request_id, None) + preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req.request_id, None) if preempted_encoder_inputs: # Restore encoder compute budget if the preempted # request had encoder inputs scheduled in this step. num_embeds_to_restore = sum( - preempted_req.get_num_encoder_embeds(i) - for i in preempted_encoder_inputs) + preempted_req.get_num_encoder_embeds(i) for i in preempted_encoder_inputs + ) encoder_compute_budget += num_embeds_to_restore req_index -= 1 else: preempted_req = self.running.pop() - self._preempt_request(preempted_req, - scheduled_timestamp) + self._preempt_request(preempted_req, scheduled_timestamp) preempted_reqs.append(preempted_req) if preempted_req == request: # No more request to preempt. Cannot schedule this request. @@ -263,23 +249,20 @@ class RecomputeScheduler(Scheduler): # Speculative decode related. if request.spec_token_ids: - num_scheduled_spec_tokens = (num_new_tokens + - request.num_computed_tokens - - request.num_tokens - - request.num_output_placeholders) + num_scheduled_spec_tokens = ( + num_new_tokens + request.num_computed_tokens - request.num_tokens - request.num_output_placeholders + ) if num_scheduled_spec_tokens > 0: # Trim spec_token_ids list to num_scheduled_spec_tokens. del request.spec_token_ids[num_scheduled_spec_tokens:] - scheduled_spec_decode_tokens[request.request_id] = ( - request.spec_token_ids) + scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids # New spec tokens will be set in `update_draft_token_ids` before the # next step when applicable. 
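The spec-token trim just above is easiest to see with concrete numbers: after the token budget shrinks `num_new_tokens`, only the draft tokens that still fit inside the step are kept. A worked example (illustrative values):

```python
# Worked example (illustrative numbers) of trimming spec_token_ids.
num_tokens = 20               # prompt + confirmed output tokens
spec_token_ids = [7, 8, 9]    # three draft tokens proposed
num_tokens_with_spec = num_tokens + len(spec_token_ids)  # 23
num_output_placeholders = 0   # synchronous scheduling in this example
num_computed_tokens = 19      # one confirmed token still to be computed

num_new_tokens = num_tokens_with_spec + num_output_placeholders - num_computed_tokens  # 4
num_new_tokens = min(num_new_tokens, 2)  # token budget trims the step to 2 tokens

num_scheduled_spec_tokens = (
    num_new_tokens + num_computed_tokens - num_tokens - num_output_placeholders
)
assert num_scheduled_spec_tokens == 1  # only the first draft token fits
del spec_token_ids[num_scheduled_spec_tokens:]
assert spec_token_ids == [7]
```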
request.spec_token_ids = [] # Encoder-related. if encoder_inputs_to_schedule: - scheduled_encoder_inputs[request.request_id] = ( - encoder_inputs_to_schedule) + scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule # Allocate the encoder cache. for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) @@ -294,8 +277,10 @@ class RecomputeScheduler(Scheduler): scheduled_loras: set[int] = set() if self.lora_config: scheduled_loras = set( - req.lora_request.lora_int_id for req in scheduled_running_reqs - if req.lora_request and req.lora_request.lora_int_id > 0) + req.lora_request.lora_int_id + for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0 + ) assert len(scheduled_loras) <= self.lora_config.max_loras # Use a temporary RequestQueue to collect requests that need to be @@ -337,9 +322,14 @@ class RecomputeScheduler(Scheduler): # Check that adding the request still respects the max_loras # constraint. - if (self.lora_config and request.lora_request and - (len(scheduled_loras) == self.lora_config.max_loras and - request.lora_request.lora_int_id not in scheduled_loras)): + if ( + self.lora_config + and request.lora_request + and ( + len(scheduled_loras) == self.lora_config.max_loras + and request.lora_request.lora_int_id not in scheduled_loras + ) + ): # Scheduling would exceed max_loras, skip. self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) @@ -351,14 +341,15 @@ class RecomputeScheduler(Scheduler): # Get already-cached tokens. if request.num_computed_tokens == 0: # Get locally-cached tokens. - new_computed_blocks, num_new_local_computed_tokens = ( - self.kv_cache_manager.get_computed_blocks(request)) + new_computed_blocks, num_new_local_computed_tokens = self.kv_cache_manager.get_computed_blocks( + request + ) # Get externally-cached tokens if using a KVConnector. if self.connector is not None: - ext_tokens, load_kv_async = ( - self.connector.get_num_new_matched_tokens( - request, num_new_local_computed_tokens)) + ext_tokens, load_kv_async = self.connector.get_num_new_matched_tokens( + request, num_new_local_computed_tokens + ) if ext_tokens is None: # The request cannot be scheduled because @@ -372,8 +363,7 @@ class RecomputeScheduler(Scheduler): num_external_computed_tokens = ext_tokens # Total computed tokens (local + external). - num_computed_tokens = (num_new_local_computed_tokens + - num_external_computed_tokens) + num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens else: # KVTransfer: WAITING reqs have num_computed_tokens > 0 # after async KV recvs are completed. @@ -401,8 +391,7 @@ class RecomputeScheduler(Scheduler): # chunked prefill has to be enabled explicitly to allow # pooling requests to be chunked - if (not self.scheduler_config.enable_chunked_prefill - and num_new_tokens > token_budget): + if not self.scheduler_config.enable_chunked_prefill and num_new_tokens > token_budget: # If chunked_prefill is disabled, # we can stop the scheduling here. break @@ -433,9 +422,7 @@ class RecomputeScheduler(Scheduler): # extra block gets allocated which # creates a mismatch between the number # of local and remote blocks. - effective_lookahead_tokens = (0 if request.num_computed_tokens - == 0 else - self.num_lookahead_tokens) + effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens # Determine if we need to allocate cross-attention blocks. 
if self.is_encoder_decoder and request.has_encoder_inputs: @@ -443,8 +430,7 @@ class RecomputeScheduler(Scheduler): # always padded to the maximum length. If we support other # encoder-decoder models, this will need to be updated if we # want to only allocate what is needed. - num_encoder_tokens = ( - self.scheduler_config.max_num_encoder_input_tokens) + num_encoder_tokens = self.scheduler_config.max_num_encoder_input_tokens else: num_encoder_tokens = 0 @@ -488,20 +474,17 @@ class RecomputeScheduler(Scheduler): req_index += 1 self.running.append(request) if self.log_stats: - request.record_event(EngineCoreEventType.SCHEDULED, - scheduled_timestamp) + request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp) if request.status == RequestStatus.WAITING: scheduled_new_reqs.append(request) elif request.status == RequestStatus.PREEMPTED: scheduled_resumed_reqs.append(request) else: - raise RuntimeError( - f"Invalid request status: {request.status}") + raise RuntimeError(f"Invalid request status: {request.status}") if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) - req_to_new_blocks[request.request_id] = ( - self.kv_cache_manager.get_blocks(request.request_id)) + req_to_new_blocks[request.request_id] = self.kv_cache_manager.get_blocks(request.request_id) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -511,8 +494,7 @@ class RecomputeScheduler(Scheduler): request.num_cached_tokens = num_computed_tokens # Encoder-related. if encoder_inputs_to_schedule: - scheduled_encoder_inputs[request.request_id] = ( - encoder_inputs_to_schedule) + scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule # Allocate the encoder cache. for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) @@ -522,8 +504,7 @@ class RecomputeScheduler(Scheduler): for i in external_load_encoder_input: self.encoder_cache_manager.allocate(request, i) if self.ec_connector is not None: - self.ec_connector.update_state_after_alloc( - request, i) + self.ec_connector.update_state_after_alloc(request, i) # Put back any skipped requests at the head of the waiting queue if skipped_waiting_requests: self.waiting.prepend_requests(skipped_waiting_requests) @@ -537,20 +518,15 @@ class RecomputeScheduler(Scheduler): # Since some requests in the RUNNING queue may not be scheduled in # this step, the total number of scheduled requests can be smaller than # len(self.running). - assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len( - scheduled_running_reqs) <= len(self.running) + assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs) <= len(self.running) # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. - num_common_prefix_blocks = [0] * len( - self.kv_cache_config.kv_cache_groups) - with record_function_or_nullcontext( - "schedule: get_num_common_prefix_blocks"): + num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) + with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"): if self.running: any_request = self.running[0] - num_common_prefix_blocks = ( - self.kv_cache_manager.get_num_common_prefix_blocks( - any_request.request_id)) + num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id) # Construct the scheduler output. 
if self.use_v2_model_runner: @@ -561,17 +537,16 @@ class RecomputeScheduler(Scheduler): req, req_to_new_blocks[req.request_id].get_block_ids(), req._all_token_ids, - ) for req in scheduled_new_reqs + ) + for req in scheduled_new_reqs ] else: new_reqs_data = [ - NewRequestData.from_request( - req, req_to_new_blocks[req.request_id].get_block_ids()) + NewRequestData.from_request(req, req_to_new_blocks[req.request_id].get_block_ids()) for req in scheduled_new_reqs ] - with record_function_or_nullcontext( - "schedule: make_cached_request_data"): + with record_function_or_nullcontext("schedule: make_cached_request_data"): cached_reqs_data = self._make_cached_request_data( scheduled_running_reqs, scheduled_resumed_reqs, @@ -592,15 +567,13 @@ class RecomputeScheduler(Scheduler): scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, scheduled_encoder_inputs=scheduled_encoder_inputs, num_common_prefix_blocks=num_common_prefix_blocks, - preempted_req_ids={req.request_id - for req in preempted_reqs}, + preempted_req_ids={req.request_id for req in preempted_reqs}, # finished_req_ids is an existing state in the scheduler, # instead of being newly scheduled in this step. # It contains the request IDs that are finished in between # the previous and the current steps. finished_req_ids=self.finished_req_ids, - free_encoder_mm_hashes=self.encoder_cache_manager. - get_freed_mm_hashes(), + free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), recomputed_reqs=recomputed_reqs, ) @@ -609,14 +582,12 @@ class RecomputeScheduler(Scheduler): # 2. Wrap up all the KV cache load / save ops into an opaque object # 3. Clear the internal states of the connector if self.connector is not None: - meta: KVConnectorMetadata = self.connector.build_connector_meta( - scheduler_output) + meta: KVConnectorMetadata = self.connector.build_connector_meta(scheduler_output) scheduler_output.kv_connector_metadata = meta # Build the connector meta for ECConnector if self.ec_connector is not None: - ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta( - scheduler_output) + ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(scheduler_output) scheduler_output.ec_connector_metadata = ec_meta with record_function_or_nullcontext("schedule: update_after_schedule"): @@ -639,8 +610,8 @@ class RecomputeScheduler(Scheduler): outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: SpecDecodingStats | None = None kv_connector_stats: KVConnectorStats | None = ( - kv_connector_output.kv_connector_stats - if kv_connector_output else None) + kv_connector_output.kv_connector_stats if kv_connector_output else None + ) if kv_connector_stats and self.connector: kv_stats = self.connector.get_kv_connector_stats() if kv_stats: @@ -651,8 +622,7 @@ class RecomputeScheduler(Scheduler): # These blocks contain externally computed tokens that failed to # load. Identify affected requests and adjust their computed token # count to trigger recomputation of the invalid blocks. 
- failed_kv_load_req_ids = self._handle_invalid_blocks( - kv_connector_output.invalid_block_ids) + failed_kv_load_req_ids = self._handle_invalid_blocks(kv_connector_output.invalid_block_ids) # return recomputed requests as EngineCoreOutput if scheduler_output.recomputed_reqs is not None: @@ -663,7 +633,8 @@ class RecomputeScheduler(Scheduler): finish_reason=FinishReason.STOP, new_token_ids=[req_info.output_token_ids[-1]], stop_reason="recomputed", - )) + ) + ) # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, # the below loop can be a performance bottleneck. We should do our best @@ -683,11 +654,9 @@ class RecomputeScheduler(Scheduler): continue req_index = model_runner_output.req_id_to_index[req_id] - generated_token_ids = (sampled_token_ids[req_index] - if sampled_token_ids else []) + generated_token_ids = sampled_token_ids[req_index] if sampled_token_ids else [] - scheduled_spec_token_ids = ( - scheduler_output.scheduled_spec_decode_tokens.get(req_id)) + scheduled_spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id) if scheduled_spec_token_ids: num_draft_tokens = len(scheduled_spec_token_ids) num_accepted = len(generated_token_ids) - 1 @@ -717,15 +686,13 @@ class RecomputeScheduler(Scheduler): # Check for stop and update request status. if new_token_ids: - new_token_ids, stopped = self._update_request_with_output( - request, new_token_ids) + new_token_ids, stopped = self._update_request_with_output(request, new_token_ids) # Stop checking for pooler models. pooler_output = None if pooler_outputs: pooler_output = pooler_outputs[req_index] - stopped = check_stop(request, self.max_model_len, - pooler_output) + stopped = check_stop(request, self.max_model_len, pooler_output) if stopped: kv_transfer_params = self._free_request(request) @@ -735,19 +702,14 @@ class RecomputeScheduler(Scheduler): stopped_preempted_reqs.add(request) # Extract sample logprobs if needed. - if (request.sampling_params is not None - and request.sampling_params.logprobs is not None - and logprobs): - new_logprobs = logprobs.slice_request(req_index, - len(new_token_ids)) + if request.sampling_params is not None and request.sampling_params.logprobs is not None and logprobs: + new_logprobs = logprobs.slice_request(req_index, len(new_token_ids)) - if new_token_ids and self.structured_output_manager.should_advance( - request): + if new_token_ids and self.structured_output_manager.should_advance(request): struct_output_request = request.structured_output_request assert struct_output_request is not None assert struct_output_request.grammar is not None - struct_output_request.grammar.accept_tokens( - req_id, new_token_ids) + struct_output_request.grammar.accept_tokens(req_id, new_token_ids) if num_nans_in_logits is not None and req_id in num_nans_in_logits: request.num_nans_in_logits = num_nans_in_logits[req_id] @@ -770,7 +732,8 @@ class RecomputeScheduler(Scheduler): trace_headers=request.trace_headers, num_cached_tokens=request.num_cached_tokens, num_nans_in_logits=request.num_nans_in_logits, - )) + ) + ) else: # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors @@ -805,10 +768,7 @@ class RecomputeScheduler(Scheduler): # Create EngineCoreOutputs for all clients that have requests with # outputs in this step. 
- engine_core_outputs = { - client_index: EngineCoreOutputs(outputs=outs) - for client_index, outs in outputs.items() - } + engine_core_outputs = {client_index: EngineCoreOutputs(outputs=outs) for client_index, outs in outputs.items()} finished_req_ids = self.finished_req_ids_dict if finished_req_ids: @@ -819,12 +779,10 @@ class RecomputeScheduler(Scheduler): if (eco := engine_core_outputs.get(client_index)) is not None: eco.finished_requests = finished_set else: - engine_core_outputs[client_index] = EngineCoreOutputs( - finished_requests=finished_set) + engine_core_outputs[client_index] = EngineCoreOutputs(finished_requests=finished_set) finished_req_ids.clear() - if (stats := self.make_stats(spec_decoding_stats, - kv_connector_stats)) is not None: + if (stats := self.make_stats(spec_decoding_stats, kv_connector_stats)) is not None: # Return stats to only one of the front-ends. if (eco := next(iter(engine_core_outputs.values()), None)) is None: # We must return the stats even if there are no request @@ -836,6 +794,5 @@ class RecomputeScheduler(Scheduler): class AsyncRecomputeScheduler(AsyncScheduler, RecomputeScheduler): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/vllm_ascend/core/scheduler_dynamic_batch.py b/vllm_ascend/core/scheduler_dynamic_batch.py index 0b78f34f..5f0898ce 100644 --- a/vllm_ascend/core/scheduler_dynamic_batch.py +++ b/vllm_ascend/core/scheduler_dynamic_batch.py @@ -16,7 +16,6 @@ # import os import time -from typing import Optional import pandas as pd from vllm.config import VllmConfig @@ -25,8 +24,7 @@ from vllm.logger import logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput -from vllm.v1.core.sched.request_queue import (SchedulingPolicy, - create_request_queue) +from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.engine import EngineCoreEventType from vllm.v1.kv_cache_interface import KVCacheConfig @@ -43,8 +41,9 @@ class BudgetRefiner: if not self.enabled: return logger.info( - "Dynamic batch is enabled with SLO limit: {}, and chunked prefill is forced to be activated because dynamic batch relies on it" - .format(str(slo_limit))) + "Dynamic batch is enabled with SLO limit: {}, and chunked prefill is " + "forced to be activated because dynamic batch relies on it".format(str(slo_limit)) + ) self.lookup: dict[tuple[int, int], int] = {} self.context_keys: set[int] = set() self.dnum_keys: set[int] = set() @@ -61,19 +60,20 @@ class BudgetRefiner: "The dynamic batching feature requires the lookup table " "'profile_table.csv', but it was not found at '%s'. 
" "Please download the corresponding table file.", - table_file_path) + table_file_path, + ) self.enabled = False return else: df = pd.read_csv(table_file_path) - grouped = df.groupby(['ctx_len', 'd_num']) + grouped = df.groupby(["ctx_len", "d_num"]) for (ctx_len, d_num), group in grouped: - valid = group[group['cost'] <= slo_limit] + valid = group[group["cost"] <= slo_limit] if not valid.empty: - max_row = valid.loc[valid['chunk_size'].idxmax()] + max_row = valid.loc[valid["chunk_size"].idxmax()] assert isinstance(ctx_len, int), "ctx_len must be an integer" assert isinstance(d_num, int), "d_num must be an integer" - self.lookup[(ctx_len, d_num)] = int(max_row['chunk_size']) + self.lookup[(ctx_len, d_num)] = int(max_row["chunk_size"]) self.context_keys.add(ctx_len) self.dnum_keys.add(d_num) self.context_keys = set(sorted(self.context_keys)) @@ -97,7 +97,10 @@ class BudgetRefiner: logger.warn(f"Table miss for ctx,dnum{aligned_ctx, aligned_dnum}") budget = self.default_budget # For debug. - # logger.info(f"budget {budget}, ctx,dnum {aligned_ctx, aligned_dnum}, raw ctx,dnum {num_deocde_tokens, num_decode}") + # logger.info( + # f"budget {budget}, ctx,dnum {aligned_ctx, aligned_dnum}, " + # f"raw ctx,dnum {num_deocde_tokens, num_decode}" + # ) return budget def refine_budget(self, running_request, budget): @@ -106,9 +109,8 @@ class BudgetRefiner: return budget # assume all running request will be scheduled. num_decode_token_lst = [ - req.num_tokens_with_spec \ - for req in running_request \ - if req.num_computed_tokens >= req.num_prompt_tokens ] + req.num_tokens_with_spec for req in running_request if req.num_computed_tokens >= req.num_prompt_tokens + ] num_decode = len(num_decode_token_lst) if num_decode <= 0: return budget @@ -125,18 +127,25 @@ class SchedulerDynamicBatch(Scheduler): vllm_config: VllmConfig, kv_cache_config: KVCacheConfig, structured_output_manager: StructuredOutputManager, - block_size: Optional[int] = None, + block_size: int | None = None, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, include_finished_set: bool = False, log_stats: bool = False, ) -> None: - super().__init__(vllm_config, kv_cache_config, - structured_output_manager, block_size, mm_registry, - include_finished_set, log_stats) + super().__init__( + vllm_config, + kv_cache_config, + structured_output_manager, + block_size, + mm_registry, + include_finished_set, + log_stats, + ) self.running: list[Request] = [] self.budget_refiner = BudgetRefiner( default_budget=self.scheduler_config.max_num_batched_tokens, - slo_limit=self.scheduler_config.SLO_limits_for_dynamic_batch) + slo_limit=self.scheduler_config.SLO_limits_for_dynamic_batch, + ) def schedule(self) -> SchedulerOutput: # NOTE: This scheduling algorithm is developed based on the "super.schedule()" @@ -159,20 +168,13 @@ class SchedulerDynamicBatch(Scheduler): req_to_new_blocks: dict[str, KVCacheBlocks] = {} num_scheduled_tokens: dict[str, int] = {} token_budget = self.max_num_scheduled_tokens - token_budget = self.budget_refiner.refine_budget( - self.running, token_budget) + token_budget = self.budget_refiner.refine_budget(self.running, token_budget) # NOTE: We move the prefill requests to the end of the self.running # list and keep the relative order unchanged. This rearrangement makes this # scheduling algorithm a strict decode-first chunked prefills. 
- d_lst = [ - req for req in self.running - if req.num_computed_tokens >= req.num_prompt_tokens - ] - p_lst = [ - req for req in self.running - if req.num_computed_tokens < req.num_prompt_tokens - ] + d_lst = [req for req in self.running if req.num_computed_tokens >= req.num_prompt_tokens] + p_lst = [req for req in self.running if req.num_computed_tokens < req.num_prompt_tokens] self.running = d_lst + p_lst # Encoder-related. @@ -189,30 +191,26 @@ class SchedulerDynamicBatch(Scheduler): while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - num_new_tokens = (request.num_tokens_with_spec + - request.num_output_placeholders - - request.num_computed_tokens) - if (0 < self.scheduler_config.long_prefill_token_threshold < - num_new_tokens): - num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = ( + request.num_tokens_with_spec + request.num_output_placeholders - request.num_computed_tokens + ) + if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: + num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = min(num_new_tokens, token_budget) # Make sure the input position does not exceed the max model len. # This is necessary when using spec decoding. - num_new_tokens = min( - num_new_tokens, - self.max_model_len - 1 - request.num_computed_tokens) + num_new_tokens = min(num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens) # Schedule encoder inputs. encoder_inputs_to_schedule = None new_encoder_compute_budget = encoder_compute_budget if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_compute_budget - ) = self._try_schedule_encoder_inputs( - request, request.num_computed_tokens, num_new_tokens, - encoder_compute_budget) + (encoder_inputs_to_schedule, num_new_tokens, new_encoder_compute_budget) = ( + self._try_schedule_encoder_inputs( + request, request.num_computed_tokens, num_new_tokens, encoder_compute_budget + ) + ) if num_new_tokens == 0: # The request cannot be scheduled because one of the following @@ -231,9 +229,8 @@ class SchedulerDynamicBatch(Scheduler): while True: new_blocks = self.kv_cache_manager.allocate_slots( - request, - num_new_tokens, - num_lookahead_tokens=self.num_lookahead_tokens) + request, num_new_tokens, num_lookahead_tokens=self.num_lookahead_tokens + ) if new_blocks is None: # The request cannot be scheduled. # Preempt the lowest-priority request. @@ -253,8 +250,7 @@ class SchedulerDynamicBatch(Scheduler): preempted_req.status = RequestStatus.PREEMPTED preempted_req.num_computed_tokens = 0 if self.log_stats: - preempted_req.record_event( - EngineCoreEventType.PREEMPTED, scheduled_timestamp) + preempted_req.record_event(EngineCoreEventType.PREEMPTED, scheduled_timestamp) self.waiting.prepend_request(preempted_req) preempted_reqs.append(preempted_req) @@ -279,19 +275,15 @@ class SchedulerDynamicBatch(Scheduler): # Speculative decode related. if request.spec_token_ids: - num_scheduled_spec_tokens = (num_new_tokens + - request.num_computed_tokens - - request.num_tokens) + num_scheduled_spec_tokens = num_new_tokens + request.num_computed_tokens - request.num_tokens if num_scheduled_spec_tokens > 0: # Trim spec_token_ids list to num_scheduled_spec_tokens. del request.spec_token_ids[num_scheduled_spec_tokens:] - scheduled_spec_decode_tokens[request.request_id] = ( - request.spec_token_ids) + scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids # Encoder-related. 
if encoder_inputs_to_schedule: - scheduled_encoder_inputs[request.request_id] = ( - encoder_inputs_to_schedule) + scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule # Allocate the encoder cache. for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) @@ -301,8 +293,10 @@ class SchedulerDynamicBatch(Scheduler): scheduled_loras: set[int] = set() if self.lora_config: scheduled_loras = set( - req.lora_request.lora_int_id for req in scheduled_running_reqs - if req.lora_request and req.lora_request.lora_int_id > 0) + req.lora_request.lora_int_id + for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0 + ) assert len(scheduled_loras) <= self.lora_config.max_loras # Use a temporary RequestQueue to collect requests that need to be @@ -323,9 +317,7 @@ class SchedulerDynamicBatch(Scheduler): if is_ready: request.status = RequestStatus.WAITING else: - logger.debug( - "%s is still in WAITING_FOR_REMOTE_KVS state.", - request.request_id) + logger.debug("%s is still in WAITING_FOR_REMOTE_KVS state.", request.request_id) self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) continue @@ -343,9 +335,14 @@ class SchedulerDynamicBatch(Scheduler): # Check that adding the request still respects the max_loras # constraint. - if (self.lora_config and request.lora_request and - (len(scheduled_loras) == self.lora_config.max_loras and - request.lora_request.lora_int_id not in scheduled_loras)): + if ( + self.lora_config + and request.lora_request + and ( + len(scheduled_loras) == self.lora_config.max_loras + and request.lora_request.lora_int_id not in scheduled_loras + ) + ): # Scheduling would exceed max_loras, skip. self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) @@ -357,15 +354,15 @@ class SchedulerDynamicBatch(Scheduler): # Get already-cached tokens. if request.num_computed_tokens == 0: # Get locally-cached tokens. - new_computed_blocks, num_new_local_computed_tokens = \ - self.kv_cache_manager.get_computed_blocks( - request) + new_computed_blocks, num_new_local_computed_tokens = self.kv_cache_manager.get_computed_blocks( + request + ) # Get externally-cached tokens if using a KVConnector. if self.connector is not None: - num_external_computed_tokens, load_kv_async = ( - self.connector.get_num_new_matched_tokens( - request, num_new_local_computed_tokens)) + num_external_computed_tokens, load_kv_async = self.connector.get_num_new_matched_tokens( + request, num_new_local_computed_tokens + ) if num_external_computed_tokens is None: # The request cannot be scheduled because @@ -376,13 +373,11 @@ class SchedulerDynamicBatch(Scheduler): continue # Total computed tokens (local + external). - num_computed_tokens = (num_new_local_computed_tokens + - num_external_computed_tokens) + num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens # KVTransfer: WAITING reqs have num_computed_tokens > 0 # after async KV recvs are completed. else: - new_computed_blocks = ( - self.kv_cache_manager.create_empty_block_list()) + new_computed_blocks = self.kv_cache_manager.create_empty_block_list() num_new_local_computed_tokens = 0 num_computed_tokens = request.num_computed_tokens @@ -399,15 +394,12 @@ class SchedulerDynamicBatch(Scheduler): # `request.num_prompt_tokens` to consider the resumed # requests, which have output tokens. 
num_new_tokens = request.num_tokens - num_computed_tokens - if (0 < self.scheduler_config.long_prefill_token_threshold - < num_new_tokens): - num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) + if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: + num_new_tokens = self.scheduler_config.long_prefill_token_threshold # chunked prefill has to be enabled explicitly to allow # pooling requests to be chunked - if not self.scheduler_config.enable_chunked_prefill and \ - num_new_tokens > token_budget: + if not self.scheduler_config.enable_chunked_prefill and num_new_tokens > token_budget: self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) continue @@ -417,11 +409,11 @@ class SchedulerDynamicBatch(Scheduler): # Schedule encoder inputs. if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_compute_budget, - _) = self._try_schedule_encoder_inputs( - request, num_computed_tokens, num_new_tokens, - encoder_compute_budget) + (encoder_inputs_to_schedule, num_new_tokens, new_encoder_compute_budget, _) = ( + self._try_schedule_encoder_inputs( + request, num_computed_tokens, num_new_tokens, encoder_compute_budget + ) + ) if num_new_tokens == 0: # The request cannot be scheduled. break @@ -431,9 +423,7 @@ class SchedulerDynamicBatch(Scheduler): # extra block gets allocated which # creates a mismatch between the number # of local and remote blocks. - effective_lookahead_tokens = (0 if request.num_computed_tokens - == 0 else - self.num_lookahead_tokens) + effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens # Determine if we need to allocate cross-attention blocks. if self.is_encoder_decoder and request.has_encoder_inputs: @@ -441,8 +431,7 @@ class SchedulerDynamicBatch(Scheduler): # always padded to the maximum length. If we support other # encoder-decoder models, this will need to be updated if we # want to only allocate what is needed. - num_encoder_tokens =\ - self.scheduler_config.max_num_encoder_input_tokens + num_encoder_tokens = self.scheduler_config.max_num_encoder_input_tokens else: num_encoder_tokens = 0 @@ -484,20 +473,17 @@ class SchedulerDynamicBatch(Scheduler): req_index += 1 self.running.append(request) if self.log_stats: - request.record_event(EngineCoreEventType.SCHEDULED, - scheduled_timestamp) + request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp) if request.status == RequestStatus.WAITING: scheduled_new_reqs.append(request) elif request.status == RequestStatus.PREEMPTED: scheduled_resumed_reqs.append(request) else: - raise RuntimeError( - f"Invalid request status: {request.status}") + raise RuntimeError(f"Invalid request status: {request.status}") if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) - req_to_new_blocks[request.request_id] = ( - self.kv_cache_manager.get_blocks(request.request_id)) + req_to_new_blocks[request.request_id] = self.kv_cache_manager.get_blocks(request.request_id) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -507,8 +493,7 @@ class SchedulerDynamicBatch(Scheduler): request.num_cached_tokens = num_computed_tokens # Encoder-related. if encoder_inputs_to_schedule: - scheduled_encoder_inputs[request.request_id] = ( - encoder_inputs_to_schedule) + scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule # Allocate the encoder cache. 
for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) @@ -526,22 +511,17 @@ class SchedulerDynamicBatch(Scheduler): # Since some requests in the RUNNING queue may not be scheduled in # this step, the total number of scheduled requests can be smaller than # len(self.running). - assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + - len(scheduled_running_reqs) <= len(self.running)) + assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs) <= len(self.running) # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. - num_common_prefix_blocks = [0] * len( - self.kv_cache_config.kv_cache_groups) + num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) if self.running: any_request = self.running[0] - num_common_prefix_blocks = ( - self.kv_cache_manager.get_num_common_prefix_blocks( - any_request.request_id)) + num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id) # Construct the scheduler output. new_reqs_data = [ - NewRequestData.from_request( - req, req_to_new_blocks[req.request_id].get_block_ids()) + NewRequestData.from_request(req, req_to_new_blocks[req.request_id].get_block_ids()) for req in scheduled_new_reqs ] cached_reqs_data = self._make_cached_request_data( @@ -564,8 +544,7 @@ class SchedulerDynamicBatch(Scheduler): # It contains the request IDs that are finished in between # the previous and the current steps. finished_req_ids=self.finished_req_ids, - free_encoder_mm_hashes=self.encoder_cache_manager. - get_freed_mm_hashes(), + free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), ) # NOTE(Kuntai): this function is designed for multiple purposes: diff --git a/vllm_ascend/distributed/device_communicators/npu_communicator.py b/vllm_ascend/distributed/device_communicators/npu_communicator.py index 7c14befa..6950c87a 100644 --- a/vllm_ascend/distributed/device_communicators/npu_communicator.py +++ b/vllm_ascend/distributed/device_communicators/npu_communicator.py @@ -14,61 +14,50 @@ # limitations under the License. # This file is a part of the vllm-ascend project. 
# -from typing import List, Optional import torch import torch.distributed as dist -from vllm.distributed.device_communicators.base_device_communicator import \ - DeviceCommunicatorBase +from vllm.distributed.device_communicators.base_device_communicator import DeviceCommunicatorBase class NPUCommunicator(DeviceCommunicatorBase): - - def __init__(self, - cpu_group: dist.ProcessGroup, - device: Optional[torch.device] = None, - device_group: Optional[dist.ProcessGroup] = None, - unique_name: str = ""): + def __init__( + self, + cpu_group: dist.ProcessGroup, + device: torch.device | None = None, + device_group: dist.ProcessGroup | None = None, + unique_name: str = "", + ): super().__init__(cpu_group, device, device_group, unique_name) # TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator # init device according to rank self.device = torch.npu.current_device() - def all_to_all(self, - input_: torch.Tensor, - scatter_dim: int = 0, - gather_dim: int = -1, - scatter_sizes: Optional[List[int]] = None, - gather_sizes: Optional[List[int]] = None) -> torch.Tensor: - + def all_to_all( + self, + input_: torch.Tensor, + scatter_dim: int = 0, + gather_dim: int = -1, + scatter_sizes: list[int] | None = None, + gather_sizes: list[int] | None = None, + ) -> torch.Tensor: if scatter_dim < 0: scatter_dim += input_.dim() if gather_dim < 0: gather_dim += input_.dim() if scatter_sizes is not None and gather_sizes is not None: - input_list = [ - t.contiguous() - for t in torch.split(input_, scatter_sizes, scatter_dim) - ] + input_list = [t.contiguous() for t in torch.split(input_, scatter_sizes, scatter_dim)] output_list = [] tensor_shape_base = input_list[self.rank].size() for i in range(self.world_size): tensor_shape = list(tensor_shape_base) tensor_shape[gather_dim] = gather_sizes[i] - output_list.append( - torch.empty(tensor_shape, - dtype=input_.dtype, - device=input_.device)) + output_list.append(torch.empty(tensor_shape, dtype=input_.dtype, device=input_.device)) else: - input_list = [ - t.contiguous() for t in torch.tensor_split( - input_, self.world_size, scatter_dim) - ] - output_list = [ - torch.empty_like(input_list[i]) for i in range(self.world_size) - ] + input_list = [t.contiguous() for t in torch.tensor_split(input_, self.world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[i]) for i in range(self.world_size)] dist.all_to_all(output_list, input_list, group=self.device_group) output_tensor = torch.cat(output_list, dim=gather_dim).contiguous() diff --git a/vllm_ascend/distributed/device_communicators/pyhccl.py b/vllm_ascend/distributed/device_communicators/pyhccl.py index 984ece79..220c48f3 100644 --- a/vllm_ascend/distributed/device_communicators/pyhccl.py +++ b/vllm_ascend/distributed/device_communicators/pyhccl.py @@ -15,7 +15,6 @@ # limitations under the License. 
# -from typing import Optional, Union import torch import torch.distributed as dist @@ -24,18 +23,23 @@ from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import logger from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import ( - HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t, hcclDataTypeEnum, - hcclRedOpTypeEnum, hcclUniqueId) + HCCLLibrary, + aclrtStream_t, + buffer_type, + hcclComm_t, + hcclDataTypeEnum, + hcclRedOpTypeEnum, + hcclUniqueId, +) from vllm_ascend.utils import current_stream class PyHcclCommunicator: - def __init__( self, - group: Union[ProcessGroup, StatelessProcessGroup], - device: Union[int, str, torch.device], - library_path: Optional[str] = None, + group: ProcessGroup | StatelessProcessGroup, + device: int | str | torch.device, + library_path: str | None = None, ): """ Args: @@ -52,7 +56,8 @@ class PyHcclCommunicator: if not isinstance(group, StatelessProcessGroup): assert dist.is_initialized() assert dist.get_backend(group) != dist.Backend.HCCL, ( - "PyHcclCommunicator should be attached to a non-HCCL group.") + "PyHcclCommunicator should be attached to a non-HCCL group." + ) # note: this rank is the rank in the group self.rank = dist.get_rank(group) self.world_size = dist.get_world_size(group) @@ -113,8 +118,7 @@ class PyHcclCommunicator: # `torch.npu.device` is a context manager that changes the # current npu device to the specified one with torch.npu.device(device): - self.comm: hcclComm_t = self.hccl.hcclCommInitRank( - self.world_size, self.unique_id, self.rank) + self.comm: hcclComm_t = self.hccl.hcclCommInitRank(self.world_size, self.unique_id, self.rank) stream = current_stream() # A small all_reduce for warmup. @@ -123,43 +127,48 @@ class PyHcclCommunicator: stream.synchronize() del data - def all_reduce(self, - in_tensor: torch.Tensor, - op: ReduceOp = ReduceOp.SUM, - stream=None) -> torch.Tensor: + def all_reduce(self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None) -> torch.Tensor: if self.disabled: return None # hccl communicator created on a specific device # will only work on tensors on the same device # otherwise it will cause "illegal memory access" assert in_tensor.device == self.device, ( - f"this hccl communicator is created to work on {self.device}, " - f"but the input tensor is on {in_tensor.device}") + f"this hccl communicator is created to work on {self.device}, but the input tensor is on {in_tensor.device}" + ) out_tensor = torch.empty_like(in_tensor) if stream is None: stream = current_stream() - self.hccl.hcclAllReduce(buffer_type(in_tensor.data_ptr()), - buffer_type(out_tensor.data_ptr()), - in_tensor.numel(), - hcclDataTypeEnum.from_torch(in_tensor.dtype), - hcclRedOpTypeEnum.from_torch(op), self.comm, - aclrtStream_t(stream.npu_stream)) + self.hccl.hcclAllReduce( + buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + hcclDataTypeEnum.from_torch(in_tensor.dtype), + hcclRedOpTypeEnum.from_torch(op), + self.comm, + aclrtStream_t(stream.npu_stream), + ) return out_tensor def broadcast(self, tensor: torch.Tensor, src: int, stream=None): if self.disabled: return assert tensor.device == self.device, ( - f"this hccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") + f"this hccl communicator is created to work on {self.device}, but the input tensor is on {tensor.device}" + ) if stream is None: stream = current_stream() if src == self.rank: buffer = buffer_type(tensor.data_ptr()) else: buffer 
= buffer_type(tensor.data_ptr())
-        self.hccl.hcclBroadcast(buffer, tensor.numel(),
-                                hcclDataTypeEnum.from_torch(tensor.dtype), src,
-                                self.comm, aclrtStream_t(stream.npu_stream))
+        self.hccl.hcclBroadcast(
+            buffer,
+            tensor.numel(),
+            hcclDataTypeEnum.from_torch(tensor.dtype),
+            src,
+            self.comm,
+            aclrtStream_t(stream.npu_stream),
+        )
diff --git a/vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py b/vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py
index 3435cc25..2808891b 100644
--- a/vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py
+++ b/vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py
@@ -18,7 +18,7 @@
 import ctypes
 import platform
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any

 import torch
 from torch.distributed import ReduceOp
@@ -107,69 +107,74 @@ class hcclRedOpTypeEnum:
 class Function:
     name: str
     restype: Any
-    argtypes: List[Any]
+    argtypes: list[Any]


 class HCCLLibrary:
     exported_functions = [
         # const char* HcclGetErrorString(HcclResult code);
         Function("HcclGetErrorString", ctypes.c_char_p, [hcclResult_t]),
-
         # HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo);
-        Function("HcclGetRootInfo", hcclResult_t,
-                 [ctypes.POINTER(hcclUniqueId)]),
-
+        Function("HcclGetRootInfo", hcclResult_t, [ctypes.POINTER(hcclUniqueId)]),
         # HcclResult HcclCommInitRootInfo(
         #   uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm);
         # note that HcclComm is a pointer type, so the last argument is a pointer to a pointer
-        Function("HcclCommInitRootInfo", hcclResult_t, [
-            ctypes.c_int,
-            ctypes.POINTER(hcclUniqueId),
-            ctypes.c_int,
-            ctypes.POINTER(hcclComm_t),
-        ]),
-
+        Function(
+            "HcclCommInitRootInfo",
+            hcclResult_t,
+            [
+                ctypes.c_int,
+                ctypes.POINTER(hcclUniqueId),
+                ctypes.c_int,
+                ctypes.POINTER(hcclComm_t),
+            ],
+        ),
         # HcclResult HcclAllReduce(
         #   void *sendBuf, void *recvBuf, uint64_t count,
         #   HcclDataType dataType, HcclReduceOp op, HcclComm comm,
         #   aclrtStream stream);
-        Function("HcclAllReduce", hcclResult_t, [
-            buffer_type,
-            buffer_type,
-            ctypes.c_size_t,
-            hcclDataType_t,
-            hcclRedOp_t,
-            hcclComm_t,
-            aclrtStream_t,
-        ]),
-
+        Function(
+            "HcclAllReduce",
+            hcclResult_t,
+            [
+                buffer_type,
+                buffer_type,
+                ctypes.c_size_t,
+                hcclDataType_t,
+                hcclRedOp_t,
+                hcclComm_t,
+                aclrtStream_t,
+            ],
+        ),
         # HcclResult HcclBroadcast(
         #   void *buf, uint64_t count,
         #   HcclDataType dataType, uint32_t root,
         #   HcclComm comm, aclrtStream stream);
-        Function("HcclBroadcast", hcclResult_t, [
-            buffer_type,
-            ctypes.c_size_t,
-            hcclDataType_t,
-            ctypes.c_int,
-            hcclComm_t,
-            aclrtStream_t,
-        ]),
-
+        Function(
+            "HcclBroadcast",
+            hcclResult_t,
+            [
+                buffer_type,
+                ctypes.c_size_t,
+                hcclDataType_t,
+                ctypes.c_int,
+                hcclComm_t,
+                aclrtStream_t,
+            ],
+        ),
         # HcclResult HcclCommDestroy(HcclComm comm);
         Function("HcclCommDestroy", hcclResult_t, [hcclComm_t]),
     ]

     # class attribute to store the mapping from the path to the library
     # to avoid loading the same library multiple times
-    path_to_library_cache: Dict[str, Any] = {}
+    path_to_library_cache: dict[str, Any] = {}

     # class attribute to store the mapping from library path
     # to the corresponding directory
-    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
-
-    def __init__(self, so_file: Optional[str] = None):
+    path_to_dict_mapping: dict[str, dict[str, Any]] = {}

+    def __init__(self, so_file: str | None = None):
         so_file = so_file or find_hccl_library()

         try:
@@ -185,12 +190,14 @@
                 "or it does not support the
current platform %s. " "If you already have the library, please set the " "environment variable HCCL_SO_PATH" - " to point to the correct hccl library path.", so_file, - platform.platform()) + " to point to the correct hccl library path.", + so_file, + platform.platform(), + ) raise e if so_file not in HCCLLibrary.path_to_dict_mapping: - _funcs: Dict[str, Any] = {} + _funcs: dict[str, Any] = {} for func in HCCLLibrary.exported_functions: f = getattr(self.lib, func.name) f.restype = func.restype @@ -209,34 +216,37 @@ class HCCLLibrary: def hcclGetUniqueId(self) -> hcclUniqueId: unique_id = hcclUniqueId() - self.HCCL_CHECK(self._funcs["HcclGetRootInfo"]( - ctypes.byref(unique_id))) + self.HCCL_CHECK(self._funcs["HcclGetRootInfo"](ctypes.byref(unique_id))) return unique_id - def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId, - rank: int) -> hcclComm_t: + def hcclCommInitRank(self, world_size: int, unique_id: hcclUniqueId, rank: int) -> hcclComm_t: comm = hcclComm_t() - self.HCCL_CHECK(self._funcs["HcclCommInitRootInfo"]( - world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm))) + self.HCCL_CHECK( + self._funcs["HcclCommInitRootInfo"](world_size, ctypes.byref(unique_id), rank, ctypes.byref(comm)) + ) return comm - def hcclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, op: int, comm: hcclComm_t, - stream: aclrtStream_t) -> None: + def hcclAllReduce( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: hcclComm_t, + stream: aclrtStream_t, + ) -> None: # `datatype` actually should be `hcclDataType_t` # and `op` should be `hcclRedOp_t` # both are aliases of `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int` # by ctypes automatically - self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count, - datatype, op, comm, - stream)) + self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count, datatype, op, comm, stream)) - def hcclBroadcast(self, buf: buffer_type, count: int, datatype: int, - root: int, comm: hcclComm_t, - stream: aclrtStream_t) -> None: - self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype, - root, comm, stream)) + def hcclBroadcast( + self, buf: buffer_type, count: int, datatype: int, root: int, comm: hcclComm_t, stream: aclrtStream_t + ) -> None: + self.HCCL_CHECK(self._funcs["HcclBroadcast"](buf, count, datatype, root, comm, stream)) def hcclCommDestroy(self, comm: hcclComm_t) -> None: self.HCCL_CHECK(self._funcs["HcclCommDestroy"](comm))
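
The three sketches below are illustrative and not part of the patch; every name and value they introduce is an assumption.

First, the shape bookkeeping behind the uneven all_to_all path in NPUCommunicator: the input is split along scatter_dim using per-rank scatter_sizes, and each receive buffer copies this rank's chunk shape except along gather_dim, which takes gather_sizes[i]. A single-process rerun of just that logic (the dims and sizes are made up; the collective call itself is left as a comment):

import torch

world_size, rank = 4, 0
scatter_dim, gather_dim = 0, 1          # example dims, not defaults
scatter_sizes = [1, 2, 3, 2]            # rows this rank scatters to each peer
gather_sizes = [4, 4, 4, 4]             # gather_dim extent received from each peer
x = torch.randn(sum(scatter_sizes), 16)

input_list = [t.contiguous() for t in torch.split(x, scatter_sizes, scatter_dim)]
base_shape = list(input_list[rank].size())  # every dim but gather_dim comes from our own chunk
output_list = []
for i in range(world_size):
    shape = list(base_shape)
    shape[gather_dim] = gather_sizes[i]
    output_list.append(torch.empty(shape, dtype=x.dtype))
# dist.all_to_all(output_list, input_list, group=...) would fill output_list here.
assert torch.cat(output_list, dim=gather_dim).shape == (1, 16)

When scatter_sizes and gather_sizes are omitted, the diff shows the fallback: torch.tensor_split into world_size equal chunks with empty_like receive buffers.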
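Second, a hedged usage sketch for PyHcclCommunicator: per the assertion in its constructor, it attaches to a non-HCCL process group and issues collectives on the current NPU stream. The gloo backend, launcher setup, and device strings here are assumptions; this only runs multi-rank on Ascend hardware:

import torch
import torch_npu  # noqa: F401  # Ascend PyTorch plugin, required for "npu" devices
import torch.distributed as dist

from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator

# Assumes launch via torchrun so MASTER_ADDR/RANK/WORLD_SIZE are set.
dist.init_process_group(backend="gloo")  # must not be an HCCL-backed group
rank = dist.get_rank()
comm = PyHcclCommunicator(group=dist.group.WORLD, device=f"npu:{rank}")

x = torch.ones(8, device=f"npu:{rank}")
out = comm.all_reduce(x)  # returns a new tensor, or None if the communicator is disabled
if out is not None:
    assert out[0].item() == float(dist.get_world_size())
comm.broadcast(x, src=0)  # in-place broadcast from rank 0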
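Finally, the Function table in HCCLLibrary is a standard ctypes binding pattern: declare each symbol's restype/argtypes once, cache the bound callables, and route every call through a return-code check. The same pattern exercised against libc on Linux, since loading libhccl.so itself requires an Ascend environment (the symbols below are libc's, not HCCL's):

import ctypes
import ctypes.util
from dataclasses import dataclass
from typing import Any


@dataclass
class Function:
    name: str
    restype: Any
    argtypes: list[Any]


exported_functions = [
    # size_t strlen(const char *s);
    Function("strlen", ctypes.c_size_t, [ctypes.c_char_p]),
    # int abs(int x);
    Function("abs", ctypes.c_int, [ctypes.c_int]),
]

lib = ctypes.CDLL(ctypes.util.find_library("c"))
_funcs: dict[str, Any] = {}
for func in exported_functions:
    f = getattr(lib, func.name)
    f.restype = func.restype  # declared once; ctypes then converts arguments
    f.argtypes = func.argtypes
    _funcs[func.name] = f

assert _funcs["strlen"](b"hccl") == 4
assert _funcs["abs"](-3) == 3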