From 329961b37534f7d57f18732509ad4385612903a7 Mon Sep 17 00:00:00 2001 From: SILONG ZENG <2609716663@qq.com> Date: Mon, 19 Jan 2026 08:59:46 +0800 Subject: [PATCH] [Lint]Style: Convert `vllm-ascend/` to ruff format(Batch #2) (#5977) ### What this PR does / why we need it? **Scope of Changes**: | File Path | | :--- | | `vllm_ascend/attention/attention_mask.py` | | `vllm_ascend/attention/attention_v1.py` | | `vllm_ascend/attention/context_parallel/attention_cp.py` | | `vllm_ascend/attention/context_parallel/common_cp.py` | | `vllm_ascend/attention/context_parallel/mla_cp.py` | | `vllm_ascend/attention/utils.py` | | `vllm_ascend/batch_invariant.py` | | `vllm_ascend/device/device_op.py` | | `vllm_ascend/device_allocator/camem.py` | | `vllm_ascend/envs.py` | - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2c24bc6996cb165fce92f780b388a5e39b3f4060 --------- Signed-off-by: MrZ20 <2609716663@qq.com> --- pyproject.toml | 8 +- vllm_ascend/attention/attention_mask.py | 29 +- vllm_ascend/attention/attention_v1.py | 368 +++++---- .../context_parallel/attention_cp.py | 708 ++++++++---------- .../attention/context_parallel/common_cp.py | 59 +- .../attention/context_parallel/mla_cp.py | 502 ++++++------- vllm_ascend/attention/utils.py | 77 +- vllm_ascend/batch_invariant.py | 15 +- vllm_ascend/device/device_op.py | 29 +- vllm_ascend/device_allocator/camem.py | 84 +-- vllm_ascend/envs.py | 86 +-- 11 files changed, 920 insertions(+), 1045 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b20a7bcc..65206975 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,11 +49,9 @@ line-length = 120 # Folder to be modified exclude = [ "tests/**", - "vllm_ascend/_cann_ops_custom", - "vllm_ascend/attention", + "vllm_ascend/attention/mla_v1.py", + "vllm_ascend/attention/sfa_v1.py", "vllm_ascend/core", - "vllm_ascend/device", - "vllm_ascend/device_allocator", "vllm_ascend/distributed", "vllm_ascend/eplb", "vllm_ascend/kv_offload", @@ -66,8 +64,6 @@ exclude = [ "vllm_ascend/spec_decode", "vllm_ascend/worker", "vllm_ascend/xlite", - "vllm_ascend/envs.py", - "vllm_ascend/batch_invariant.py", ] [tool.ruff.lint] diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py index e7823b9e..d76d257c 100644 --- a/vllm_ascend/attention/attention_mask.py +++ b/vllm_ascend/attention/attention_mask.py @@ -21,21 +21,18 @@ from vllm_ascend.utils import singleton def _generate_attn_mask(max_seq_len, dtype): # Construct lower triangle matrix. - mask_flag = torch.ones((max_seq_len, max_seq_len), - dtype=torch.bool).tril_() + mask_flag = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool).tril_() # Create upper triangle matrix used to mark mask positions. mask_flag = ~mask_flag # Currently for fp16 dtype, the mask value should be set to -inf. # TODO: Eliminate this part in the future. - mask_value = float('-inf') if dtype == torch.float16 else 1 - attn_mask = torch.zeros(size=(max_seq_len, max_seq_len), dtype=dtype) \ - .masked_fill_(mask_flag, mask_value) + mask_value = float("-inf") if dtype == torch.float16 else 1 + attn_mask = torch.zeros(size=(max_seq_len, max_seq_len), dtype=dtype).masked_fill_(mask_flag, mask_value) return attn_mask @singleton class AttentionMaskBuilder: - def __init__(self, device: torch.device): self.attn_mask_cache = None self._seq_len_cached = 0 @@ -52,14 +49,13 @@ class AttentionMaskBuilder: assert self.attn_mask_cache is not None, "Something is wrong in generate_attn_mask." 
if self.attn_mask_cache.dtype != dtype: self.attn_mask_cache = self.attn_mask_cache.to(dtype) - return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous( - ).to(self.device, non_blocking=True) + return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous().to(self.device, non_blocking=True) def get_splitfuse_attn_mask(self) -> torch.Tensor: if self.chunked_prefill_attn_mask is None: - self.chunked_prefill_attn_mask = torch.triu( - torch.ones(2048, - 2048), diagonal=1).to(torch.int8).to(self.device) + self.chunked_prefill_attn_mask = ( + torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8).to(self.device) + ) return self.chunked_prefill_attn_mask def get_mla_mask(self, dtype: torch.dtype) -> torch.Tensor: @@ -68,16 +64,13 @@ class AttentionMaskBuilder: mask_value = torch.finfo(torch.float32).min else: mask_value = 1 - prefill_mask = torch.triu( - torch.ones(512, 512, device=self.device, dtype=dtype), 1) - self.mla_mask = torch.where(prefill_mask == 1, mask_value, - 0).to(dtype) + prefill_mask = torch.triu(torch.ones(512, 512, device=self.device, dtype=dtype), 1) + self.mla_mask = torch.where(prefill_mask == 1, mask_value, 0).to(dtype) return self.mla_mask def get_pcp_mla_mask(self, dtype: torch.dtype): if self.pcp_mla_mask is None or self.pcp_mla_mask.dtype != dtype: - self.pcp_mla_mask = torch.triu( - torch.ones(512, 512, device=self.device, dtype=dtype), 1) + self.pcp_mla_mask = torch.triu(torch.ones(512, 512, device=self.device, dtype=dtype), 1) return self.pcp_mla_mask def get_swa_mask(self, dtype: torch.dtype, sliding_window): @@ -99,4 +92,4 @@ class AttentionMaskBuilder: if get_pcp_group().world_size > 1: return self.get_pcp_mla_mask(model_config.dtype) # Prefill stages use 512x512 mask with appropriate dtype - return self.get_mla_mask(model_config.dtype) \ No newline at end of file + return self.get_mla_mask(model_config.dtype) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index bc1654ec..933fde28 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from enum import Enum -from typing import ClassVar, List, Optional, Tuple, Type +from typing import ClassVar import torch import torch_npu @@ -29,32 +29,49 @@ from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec from vllm_ascend.attention.attention_mask import AttentionMaskBuilder -from vllm_ascend.attention.context_parallel.common_cp import ( - AscendMetadataForDecode, AscendMetadataForPrefill) -from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, - enable_cp, split_decodes_and_prefills, - using_paged_attention) +from vllm_ascend.attention.context_parallel.common_cp import AscendMetadataForDecode, AscendMetadataForPrefill +from vllm_ascend.attention.utils import ( + AscendCommonAttentionMetadata, + enable_cp, + split_decodes_and_prefills, + using_paged_attention, +) from vllm_ascend.compilation.acl_graph import ( - get_draft_graph_params, get_graph_params, - update_draft_graph_params_workspaces, update_graph_params_workspaces) + get_draft_graph_params, + get_graph_params, + update_draft_graph_params_workspaces, + update_graph_params_workspaces, +) from vllm_ascend.device.device_op import DeviceOperator from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager from vllm_ascend.utils import vllm_version_is, weak_ref_tensors # isort: off -if 
vllm_version_is('0.13.0'): - from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder) +if vllm_version_is("0.13.0"): + from vllm.v1.attention.backends.utils import AttentionCGSupport, AttentionMetadataBuilder from vllm.attention.backends.abstract import ( # type: ignore - AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) + AttentionBackend, + AttentionImpl, + AttentionLayer, + AttentionType, + ) from vllm.attention.backends.registry import ( # type: ignore - AttentionBackendEnum, register_backend) + AttentionBackendEnum, + register_backend, + ) else: from vllm.v1.attention.backend import ( # type: ignore - AttentionBackend, AttentionCGSupport, AttentionImpl, AttentionLayer, - AttentionType, AttentionMetadataBuilder) + AttentionBackend, + AttentionCGSupport, + AttentionImpl, + AttentionLayer, + AttentionType, + AttentionMetadataBuilder, + ) from vllm.v1.attention.backends.registry import ( # type: ignore - AttentionBackendEnum, register_backend) + AttentionBackendEnum, + register_backend, + ) # isort: on # default max value of sliding window size @@ -73,18 +90,18 @@ class AscendAttentionBackend(AttentionBackend): return "CUSTOM" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN" @staticmethod - def get_impl_cls() -> Type["AscendAttentionBackendImpl"]: + def get_impl_cls() -> type["AscendAttentionBackendImpl"]: if enable_cp(): - from vllm_ascend.attention.context_parallel.attention_cp import \ - AscendAttentionCPImpl + from vllm_ascend.attention.context_parallel.attention_cp import AscendAttentionCPImpl + return AscendAttentionCPImpl return AscendAttentionBackendImpl @staticmethod def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]: if enable_cp(): - from vllm_ascend.attention.context_parallel.attention_cp import \ - AscendAttentionCPMetadataBuilder + from vllm_ascend.attention.context_parallel.attention_cp import AscendAttentionCPMetadataBuilder + return AscendAttentionCPMetadataBuilder return AscendAttentionMetadataBuilder @@ -94,13 +111,13 @@ class AscendAttentionBackend(AttentionBackend): block_size: int, num_kv_heads: int, head_size: int, - ) -> Tuple[int, ...]: + ) -> tuple[int, ...]: return (2, num_blocks, block_size, num_kv_heads, head_size) @staticmethod def swap_blocks( - src_kv_cache: List[torch.Tensor], - dst_kv_cache: List[torch.Tensor], + src_kv_cache: list[torch.Tensor], + dst_kv_cache: list[torch.Tensor], src_to_dst: torch.Tensor, ) -> None: src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1] @@ -108,14 +125,12 @@ class AscendAttentionBackend(AttentionBackend): src_indices = src_to_dst[:, 0] dst_indices = src_to_dst[:, 1] - dst_key_cache[dst_indices] = src_key_cache[src_indices].to( - dst_key_cache.device) - dst_value_cache[dst_indices] = src_value_cache[src_indices].to( - dst_key_cache.device) + dst_key_cache[dst_indices] = src_key_cache[src_indices].to(dst_key_cache.device) + dst_value_cache[dst_indices] = src_value_cache[src_indices].to(dst_key_cache.device) @staticmethod def copy_blocks( - kv_caches: List[torch.Tensor], + kv_caches: list[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: src_indices = src_to_dists[:, 0] @@ -148,8 +163,9 @@ class AscendMetadata: Contains attention masks, token counts, sequence lengths and KV cache related properties for attention computation. """ + # **************************** Basic Properties ************************** # - attn_mask: Optional[torch.Tensor] = None + attn_mask: torch.Tensor | None = None # Current state of this attention run. 
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill @@ -168,12 +184,12 @@ class AscendMetadata: # should simplified these parameters once attention schema in vLLM-Ascend # is unified. seq_lens: torch.Tensor = None - seq_lens_list: List[int] = None # type: ignore - actual_seq_lengths_q: List[int] = None # type: ignore + seq_lens_list: list[int] = None # type: ignore + actual_seq_lengths_q: list[int] = None # type: ignore query_start_loc: torch.Tensor = None # Maximum query length in the batch (None for decoding). - max_query_len: Optional[int] = None + max_query_len: int | None = None # ********************** KV Cache Related Properties ********************* # # Block addresses per sequence (Seq id -> list of physical block). @@ -187,9 +203,9 @@ class AscendMetadata: # (num_tokens,) slot_mapping: torch.Tensor = None # pcp - prefill: Optional[AscendMetadataForPrefill] = None + prefill: AscendMetadataForPrefill | None = None # dcp - decode_meta: Optional[AscendMetadataForDecode] = None + decode_meta: AscendMetadataForDecode | None = None causal: bool = True # runner_type in model_config. @@ -198,7 +214,7 @@ class AscendMetadata: reshape_cache_event: torch.npu.Event = None # sliding window attention mask - swa_mask: Optional[torch.Tensor] = None + swa_mask: torch.Tensor | None = None class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): @@ -208,6 +224,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): Handles attention mask generation and metadata preparation for Ascend FlashAttention backend. """ + # Does this backend/builder reorder the batch? # If not, set this to None. Otherwise set it to the query # length that will be pulled into the front of the batch. @@ -226,17 +243,19 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): self.compilation_config = vllm_config.compilation_config self.device = device self.max_num_blocks_per_req = cdiv( - self.model_config.max_model_len, - AscendAttentionBackend.get_supported_block_size()[0]) + self.model_config.max_model_len, AscendAttentionBackend.get_supported_block_size()[0] + ) self.speculative_config = vllm_config.speculative_config self.decode_threshold = 1 if self.speculative_config: spec_token_num = self.speculative_config.num_speculative_tokens self.decode_threshold += spec_token_num - assert self.decode_threshold <= 16, f"decode_threshold exceeded \ + assert self.decode_threshold <= 16, ( + f"decode_threshold exceeded \ npu_fused_infer_attention_score TND layout's limit of 16, \ got {self.decode_threshold}" + ) AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold @@ -254,8 +273,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): # @override omitted only because of mypy limitation due to type variable. 
return AttentionCGSupport.ALWAYS - def reorder_batch(self, input_batch, - scheduler_output: "SchedulerOutput") -> bool: + def reorder_batch(self, input_batch, scheduler_output: "SchedulerOutput") -> bool: return False def build( @@ -266,12 +284,11 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): ) -> AscendMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: - num_reqs - + 1] + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1] - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split_decodes_and_prefills( + common_attn_metadata, decode_threshold=self.decode_threshold + ) block_table = common_attn_metadata.block_table_tensor seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] @@ -283,19 +300,17 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): attn_state = common_attn_metadata.attn_state # Get attn_mask and swa_mask from singleton AttentionMaskBuilder - attn_mask = self.attn_mask_builder.get_attention_mask( - self.model_config) + attn_mask = self.attn_mask_builder.get_attention_mask(self.model_config) swa_mask = None - is_swa = hasattr(self.model_config.hf_text_config, 'sliding_window') + is_swa = hasattr(self.model_config.hf_text_config, "sliding_window") if self.model_config is not None and is_swa: swa_mask = self.attn_mask_builder.get_swa_mask( - self.model_config.dtype, - self.model_config.hf_text_config.sliding_window) + self.model_config.dtype, self.model_config.hf_text_config.sliding_window + ) # TODO: Yet another unnecessary H2D while we already have a query_start_loc on device - query_start_loc = query_start_loc_cpu.pin_memory().to( - self.device, non_blocking=True) + query_start_loc = query_start_loc_cpu.pin_memory().to(self.device, non_blocking=True) attn_metadata = AscendMetadata( num_actual_tokens=num_actual_tokens, @@ -313,7 +328,8 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): num_prefills=num_prefills, num_decodes=num_decodes, causal=common_attn_metadata.causal, - model_runner_type=self.model_config.runner_type) + model_runner_type=self.model_config.runner_type, + ) return attn_metadata def build_for_graph_capture( @@ -321,9 +337,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): common_attn_metadata: AscendCommonAttentionMetadata, attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, ): - - if attn_state in (AscendAttentionState.DecodeOnly, - AscendAttentionState.ChunkedPrefill): + if attn_state in (AscendAttentionState.DecodeOnly, AscendAttentionState.ChunkedPrefill): attn_metadata = self.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, @@ -338,19 +352,18 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): class AscendAttentionBackendImpl(AttentionImpl): - def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float], + logits_soft_cap: float | None, attn_type: str, - kv_sharing_target_layer_name: Optional[str], + 
kv_sharing_target_layer_name: str | None, **kwargs, ) -> None: self.vllm_config = get_current_vllm_config() @@ -362,9 +375,7 @@ class AscendAttentionBackendImpl(AttentionImpl): self.kv_cache_dtype = kv_cache_dtype self.sliding_window = sliding_window if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, - dtype=torch.float32, - device="npu") + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32, device="npu") self.alibi_slopes = alibi_slopes self.attn_type = attn_type @@ -372,18 +383,24 @@ class AscendAttentionBackendImpl(AttentionImpl): self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.key_cache = None self.value_cache = None - self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer + self.is_kv_producer = ( + self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer + ) def process_weights_after_loading(self, act_dtype: torch.dtype): super().process_weights_after_loading(act_dtype) if flashcomm2_oshard_manager.flashcomm2_oshard_enable(): flashcomm2_oshard_manager.post_process_after_loading() - def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor, - value: torch.Tensor, attn_metadata: AscendMetadata, - output: torch.Tensor) -> torch.Tensor: - key, value, block_size, block_table, actual_seq_lengths_kv \ - = self._get_fia_params(key, value, attn_metadata) + def full_graph_fia( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AscendMetadata, + output: torch.Tensor, + ) -> torch.Tensor: + key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata) num_tokens = attn_metadata.actual_seq_lengths_q[-1] forward_context = get_forward_context() @@ -427,12 +444,22 @@ class AscendAttentionBackendImpl(AttentionImpl): event.reset(stream) graph_params.events[num_tokens].append(event) graph_params.attn_params[num_tokens].append( - (weak_ref_tensors(query), weak_ref_tensors(key), - weak_ref_tensors(value), weak_ref_tensors(block_table), - weak_ref_tensors(attn_metadata.attn_mask), block_size, - actual_seq_lengths_kv, actual_seq_lengths_q, self.num_kv_heads, - self.num_heads, self.scale, weak_ref_tensors(output), - weak_ref_tensors(softmax_lse))) + ( + weak_ref_tensors(query), + weak_ref_tensors(key), + weak_ref_tensors(value), + weak_ref_tensors(block_table), + weak_ref_tensors(attn_metadata.attn_mask), + block_size, + actual_seq_lengths_kv, + actual_seq_lengths_q, + self.num_kv_heads, + self.num_heads, + self.scale, + weak_ref_tensors(output), + weak_ref_tensors(softmax_lse), + ) + ) torch.npu.graph_task_group_begin(stream) torch_npu.npu_fused_infer_attention_score.out( @@ -463,7 +490,7 @@ class AscendAttentionBackendImpl(AttentionImpl): self, query: torch.Tensor, attn_metadata: AscendMetadata, - output: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, ): graph_params = get_graph_params() forward_context: ForwardContext = get_forward_context() @@ -481,7 +508,8 @@ class AscendAttentionBackendImpl(AttentionImpl): scale_value=self.scale, block_table=attn_metadata.block_tables, context_lens=attn_metadata.seq_lens, - out=output) + out=output, + ) update_graph_params_workspaces(num_tokens, workspace) # Handle graph capturing mode @@ -491,17 +519,19 @@ class AscendAttentionBackendImpl(AttentionImpl): event.wait(stream) event.reset(stream) graph_params.events[num_tokens].append(event) - graph_params.attn_params[num_tokens].append(( - 
weak_ref_tensors(query), - weak_ref_tensors(self.key_cache), - weak_ref_tensors(self.value_cache), - self.num_kv_heads, - self.num_heads, - self.scale, - attn_metadata.block_tables, - attn_metadata.seq_lens, - weak_ref_tensors(output), - )) + graph_params.attn_params[num_tokens].append( + ( + weak_ref_tensors(query), + weak_ref_tensors(self.key_cache), + weak_ref_tensors(self.value_cache), + self.num_kv_heads, + self.num_heads, + self.scale, + attn_metadata.block_tables, + attn_metadata.seq_lens, + weak_ref_tensors(output), + ) + ) torch.npu.graph_task_group_begin(stream) torch_npu._npu_paged_attention( @@ -514,53 +544,54 @@ class AscendAttentionBackendImpl(AttentionImpl): block_table=attn_metadata.block_tables, context_lens=attn_metadata.seq_lens, out=output, - workspace=workspace) + workspace=workspace, + ) handle = torch.npu.graph_task_group_end(stream) graph_params.handles[num_tokens].append(handle) return output - def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor, - attn_metadata: AscendMetadata): - + def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor, attn_metadata: AscendMetadata): if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: block_size = 128 block_table = None actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q if self.attn_type == AttentionType.ENCODER_DECODER: - actual_seq_lengths_kv = torch.cumsum(attn_metadata.seq_lens, - dim=0).tolist() - elif attn_metadata.attn_state == \ - AscendAttentionState.PrefillCacheHit: + actual_seq_lengths_kv = torch.cumsum(attn_metadata.seq_lens, dim=0).tolist() + elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: batch_size = attn_metadata.seq_lens.shape[0] block_table = attn_metadata.block_tables[:batch_size, :] num_block, block_size, _, _ = self.key_cache.shape # type: ignore key = self.key_cache.view( # type: ignore - num_block, block_size, -1) + num_block, block_size, -1 + ) value = self.value_cache.view( # type: ignore - num_block, block_size, -1) + num_block, block_size, -1 + ) actual_seq_lengths_kv = attn_metadata.seq_lens_list elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: num_block, block_size, _, _ = self.key_cache.shape # type: ignore key = self.key_cache.view( # type: ignore - num_block, block_size, -1) + num_block, block_size, -1 + ) value = self.value_cache.view( # type: ignore - num_block, block_size, -1) + num_block, block_size, -1 + ) block_table = attn_metadata.block_tables actual_seq_lengths_kv = attn_metadata.seq_lens_list # chunked prefill. 
else: num_block, block_size, _, _ = self.key_cache.shape # type: ignore key = self.key_cache.view( # type: ignore - num_block, block_size, -1) + num_block, block_size, -1 + ) value = self.value_cache.view( # type: ignore - num_block, block_size, -1) + num_block, block_size, -1 + ) block_table = attn_metadata.block_tables actual_seq_lengths_kv = attn_metadata.seq_lens_list return key, value, block_size, block_table, actual_seq_lengths_kv - def _forward_fia_slidingwindow(self, query: torch.Tensor, - attn_metadata: AscendMetadata, - output: torch.Tensor): + def _forward_fia_slidingwindow(self, query: torch.Tensor, attn_metadata: AscendMetadata, output: torch.Tensor): batch_size = attn_metadata.seq_lens.shape[0] block_size = 128 query = query.view(batch_size, 1, self.num_heads * self.head_size) @@ -583,34 +614,41 @@ class AscendAttentionBackendImpl(AttentionImpl): scale=self.scale, block_table=attn_metadata.block_tables, actual_seq_lengths=[1] * len(attn_metadata.seq_lens), - actual_seq_lengths_kv=attn_metadata.seq_lens) + actual_seq_lengths_kv=attn_metadata.seq_lens, + ) output = output.view(batch_size, self.num_heads, self.head_size) return output - def forward_fused_infer_attention(self, query: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - attn_metadata: AscendMetadata, - output: torch.Tensor): + def forward_fused_infer_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AscendMetadata, + output: torch.Tensor, + ): forward_context: ForwardContext = get_forward_context() # we inherit ForwardContext in model runner v2, when enable model # runner v2, there is not capturing attribute in forward_context, # just use getattr to avoid attribute error. if getattr(forward_context, "capturing", False): - attn_output, num_tokens = self.full_graph_fia( - query, key, value, attn_metadata, output) + attn_output, num_tokens = self.full_graph_fia(query, key, value, attn_metadata, output) output[:num_tokens] = attn_output[:num_tokens] return output - if (attn_metadata.attn_state == AscendAttentionState.DecodeOnly - and self.sliding_window is not None - and attn_metadata.seq_lens.shape[0] == query.size(0)): - return self._forward_fia_slidingwindow(query, attn_metadata, - output) - key, value, block_size, block_table, actual_seq_lengths_kv \ - = self._get_fia_params(key, value, attn_metadata) + if ( + attn_metadata.attn_state == AscendAttentionState.DecodeOnly + and self.sliding_window is not None + and attn_metadata.seq_lens.shape[0] == query.size(0) + ): + return self._forward_fia_slidingwindow(query, attn_metadata, output) + key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata) num_tokens = attn_metadata.actual_seq_lengths_q[-1] query = query[:num_tokens] - if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache and self.attn_type != AttentionType.ENCODER_DECODER: + if ( + attn_metadata.attn_state == AscendAttentionState.PrefillNoCache + and self.attn_type != AttentionType.ENCODER_DECODER + ): key = key[:num_tokens] value = value[:num_tokens] # Get workspace from cache or calculate it if not present. 
@@ -630,8 +668,7 @@ class AscendAttentionBackendImpl(AttentionImpl): sparse_mode=3, ) - attn_output = attn_output.view(num_tokens, self.num_heads, - self.head_size) + attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size) output[:num_tokens] = attn_output[:num_tokens] return output @@ -639,26 +676,32 @@ class AscendAttentionBackendImpl(AttentionImpl): self, query: torch.Tensor, attn_metadata: AscendMetadata, - output: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, ) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() if forward_context.capturing: return self.full_graph_pa(query, attn_metadata, output) - torch_npu._npu_paged_attention(query=query, - key_cache=self.key_cache, - value_cache=self.value_cache, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - block_table=attn_metadata.block_tables, - context_lens=attn_metadata.seq_lens, - out=output) + torch_npu._npu_paged_attention( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + block_table=attn_metadata.block_tables, + context_lens=attn_metadata.seq_lens, + out=output, + ) return output - def _forward_encoder_attention(self, query: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - attn_metadata: AscendMetadata, - _: torch.Tensor) -> torch.Tensor: + def _forward_encoder_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AscendMetadata, + _: torch.Tensor, + ) -> torch.Tensor: assert attn_metadata is not None if attn_metadata.causal: @@ -692,26 +735,23 @@ class AscendAttentionBackendImpl(AttentionImpl): self, key: torch.Tensor, value: torch.Tensor, - kv_cache: Tuple[torch.Tensor], + kv_cache: tuple[torch.Tensor], attn_metadata: AscendMetadata, ): - if len(kv_cache) > 1: if self.is_kv_producer: attn_metadata.reshape_cache_event = torch.npu.Event() if self.key_cache is None: self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] slots = attn_metadata.slot_mapping - encoder_decoder = (self.attn_type == AttentionType.ENCODER_DECODER) + encoder_decoder = self.attn_type == AttentionType.ENCODER_DECODER DeviceOperator.reshape_and_cache( - key=key[:attn_metadata.num_actual_tokens] - if not encoder_decoder else key, - value=value[:attn_metadata.num_actual_tokens] - if not encoder_decoder else value, + key=key[: attn_metadata.num_actual_tokens] if not encoder_decoder else key, + value=value[: attn_metadata.num_actual_tokens] if not encoder_decoder else value, key_cache=self.key_cache, value_cache=self.value_cache, - slot_mapping=slots[:attn_metadata.num_actual_tokens] - if not encoder_decoder else slots) + slot_mapping=slots[: attn_metadata.num_actual_tokens] if not encoder_decoder else slots, + ) if self.is_kv_producer: attn_metadata.reshape_cache_event.record() return key, value @@ -721,18 +761,19 @@ class AscendAttentionBackendImpl(AttentionImpl): query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: Tuple[torch.Tensor], + kv_cache: tuple[torch.Tensor], attn_metadata: AscendMetadata, output: torch.Tensor, ): num_tokens = query.shape[0] - if (attn_metadata.attn_state == AscendAttentionState.DecodeOnly - and using_paged_attention(num_tokens, self.vllm_config) - and self.sliding_window is None): + if ( + attn_metadata.attn_state == AscendAttentionState.DecodeOnly + and using_paged_attention(num_tokens, self.vllm_config) + and self.sliding_window is None + 
): output = self.forward_paged_attention(query, attn_metadata, output) else: - output = self.forward_fused_infer_attention( - query, key, value, attn_metadata, output) + output = self.forward_fused_infer_attention(query, key, value, attn_metadata, output) return output @@ -742,11 +783,11 @@ class AscendAttentionBackendImpl(AttentionImpl): query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: Tuple[torch.Tensor], + kv_cache: tuple[torch.Tensor], attn_metadata: AscendMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with Ascend attention. Args: @@ -762,23 +803,18 @@ class AscendAttentionBackendImpl(AttentionImpl): assert output is not None, "Output tensor must be provided." if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for AscendAttentionBackendImpl") + raise NotImplementedError("fused output quantization is not yet supported for AscendAttentionBackendImpl") assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 num_tokens = query.shape[0] if attn_metadata is None: return output.fill_(0) if key is not None and value is not None: - key, value = self.reshape_and_cache(key, value, kv_cache, - attn_metadata) + key, value = self.reshape_and_cache(key, value, kv_cache, attn_metadata) # pooling model branch if attn_metadata.model_runner_type == "pooling": - attn_output = self._forward_encoder_attention( - query, key, value, attn_metadata, output) + attn_output = self._forward_encoder_attention(query, key, value, attn_metadata, output) output[:num_tokens] = attn_output[:num_tokens] return output - output = self.forward_impl(query, key, value, kv_cache, attn_metadata, - output) + output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output) return output diff --git a/vllm_ascend/attention/context_parallel/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py index 1ccc6b78..9cbfebcd 100644 --- a/vllm_ascend/attention/context_parallel/attention_cp.py +++ b/vllm_ascend/attention/context_parallel/attention_cp.py @@ -15,35 +15,43 @@ # This file is a part of the vllm-ascend project. 
# -from typing import ClassVar, List, Optional, Tuple +from typing import ClassVar import numpy as np import torch import torch.distributed as dist import torch_npu from vllm.config import VllmConfig -from vllm.distributed import (get_dcp_group, - get_decode_context_model_parallel_rank, - get_decode_context_model_parallel_world_size, - get_pcp_group) +from vllm.distributed import ( + get_dcp_group, + get_decode_context_model_parallel_rank, + get_decode_context_model_parallel_world_size, + get_pcp_group, +) from vllm.forward_context import ForwardContext, get_forward_context from vllm.v1.kv_cache_interface import AttentionSpec -from vllm_ascend.attention.attention_v1 import (AscendAttentionBackendImpl, - AscendAttentionMetadataBuilder, - AscendMetadata) +from vllm_ascend.attention.attention_v1 import ( + AscendAttentionBackendImpl, + AscendAttentionMetadataBuilder, + AscendMetadata, +) from vllm_ascend.attention.context_parallel.common_cp import ( - AscendMetadataForDecode, AscendMetadataForPrefill, AscendPCPMetadata, - _npu_attention_update, _process_attn_out_lse) -from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, - filter_chunked_req_indices, - split_decodes_and_prefills) -from vllm_ascend.compilation.acl_graph import (get_graph_params, - update_graph_params_workspaces) -from vllm_ascend.utils import (cp_chunkedprefill_comm_stream, vllm_version_is, - weak_ref_tensors) + AscendMetadataForDecode, + AscendMetadataForPrefill, + AscendPCPMetadata, + _npu_attention_update, + _process_attn_out_lse, +) +from vllm_ascend.attention.utils import ( + AscendCommonAttentionMetadata, + filter_chunked_req_indices, + split_decodes_and_prefills, +) +from vllm_ascend.compilation.acl_graph import get_graph_params, update_graph_params_workspaces +from vllm_ascend.utils import cp_chunkedprefill_comm_stream, vllm_version_is, weak_ref_tensors -if vllm_version_is('0.13.0'): +if vllm_version_is("0.13.0"): from vllm.v1.attention.backends.utils import AttentionCGSupport else: from vllm.v1.attention.backend import AttentionCGSupport @@ -55,6 +63,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): Extends AscendAttentionMetadataBuilder with PCP/DCP metadata handling. """ + # Does this backend/builder reorder the batch? # If not, set this to None. Otherwise set it to the query # length that will be pulled into the front of the batch. @@ -69,15 +78,12 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.batch_seq_mask_buf = torch.empty( - vllm_config.scheduler_config.max_num_batched_tokens, - dtype=torch.uint8, - device=device) + vllm_config.scheduler_config.max_num_batched_tokens, dtype=torch.uint8, device=device + ) self.pcp_size = get_pcp_group().world_size - self.pcp_rank = get_pcp_group( - ).rank_in_group if self.pcp_size > 1 else 0 + self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0 self.dcp_size = get_decode_context_model_parallel_world_size() - self.dcp_rank = get_decode_context_model_parallel_rank( - ) if self.dcp_size > 1 else 0 + self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0 @classmethod def get_cudagraph_support( @@ -89,7 +95,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): # @override omitted only because of mypy limitation due to type variable. 
return AttentionCGSupport.ALWAYS - def _get_chunked_req_mask(self, local_context_lens_allranks) -> List[bool]: + def _get_chunked_req_mask(self, local_context_lens_allranks) -> list[bool]: """ given 4-d list [req][pcp][dcp], return: 1. if each req has any chunk (list[bool]) @@ -97,9 +103,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): assert local_context_lens_allranks is not None if len(local_context_lens_allranks) == 0: return [] - chunked_req_mask = [(req.sum() > 0).item() - for req in local_context_lens_allranks - if req is not None] + chunked_req_mask = [(req.sum() > 0).item() for req in local_context_lens_allranks if req is not None] return chunked_req_mask def build( @@ -110,12 +114,11 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): ): num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: - num_reqs - + 1] + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1] - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split_decodes_and_prefills( + common_attn_metadata, decode_threshold=self.decode_threshold + ) assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_actual_tokens @@ -128,22 +131,18 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): if num_actual_tokens_pcp_padded is None: num_actual_tokens_pcp_padded = num_actual_tokens - slot_mapping = common_attn_metadata.slot_mapping[: - num_actual_tokens_pcp_padded] - attn_mask = self.attn_mask_builder.get_attention_mask( - self.model_config) + slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded] + attn_mask = self.attn_mask_builder.get_attention_mask(self.model_config) attn_state = common_attn_metadata.attn_state - num_computed_tokens_cpu = (seq_lens - query_lens) + num_computed_tokens_cpu = seq_lens - query_lens - query_start_loc = query_start_loc_cpu.to(self.device, - non_blocking=True) + query_start_loc = query_start_loc_cpu.to(self.device, non_blocking=True) common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata prefill_metadata = None decode_metadata = None if common_long_seq_metadata is None: - raise AssertionError( - "common_long_seq_metadata should not be None.") + raise AssertionError("common_long_seq_metadata should not be None.") num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp assert num_computed_tokens_of_pcp_dcp is not None chunked_context_metadata = None @@ -153,98 +152,86 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): max_context_len_cpu = context_lens_cpu.max().item() pcp_size = get_pcp_group().world_size if self.chunked_prefill_enabled and max_context_len_cpu > 0: - local_context_lens_allranks = torch.tensor( - num_computed_tokens_of_pcp_dcp)[num_decodes:num_reqs].to( - self.device).to(dtype=torch.int32) - local_chunked_kv_lens_rank = local_context_lens_allranks[:, - self. - pcp_rank, - self. 
- dcp_rank] - actual_seq_lengths_kv = torch.cumsum( - local_chunked_kv_lens_rank, dim=0).tolist() + local_context_lens_allranks = ( + torch.tensor(num_computed_tokens_of_pcp_dcp)[num_decodes:num_reqs] + .to(self.device) + .to(dtype=torch.int32) + ) + local_chunked_kv_lens_rank = local_context_lens_allranks[:, self.pcp_rank, self.dcp_rank] + actual_seq_lengths_kv = torch.cumsum(local_chunked_kv_lens_rank, dim=0).tolist() local_total_toks = local_chunked_kv_lens_rank.sum() - chunked_req_mask = self._get_chunked_req_mask( - local_context_lens_allranks) + chunked_req_mask = self._get_chunked_req_mask(local_context_lens_allranks) local_chunk_starts = torch.zeros( - (len(local_context_lens_allranks)), - dtype=torch.int32, - device=self.device) + (len(local_context_lens_allranks)), dtype=torch.int32, device=self.device + ) cp_kv_recover_idx_for_chunk = common_long_seq_metadata.cp_kv_recover_idx_for_chunk - kv_inverse_idx_for_chunk = torch.argsort( - cp_kv_recover_idx_for_chunk.to(torch.float32) - ) if cp_kv_recover_idx_for_chunk is not None else None + kv_inverse_idx_for_chunk = ( + torch.argsort(cp_kv_recover_idx_for_chunk.to(torch.float32)) + if cp_kv_recover_idx_for_chunk is not None + else None + ) - batch_chunk_seq_mask = ( - local_context_lens_allranks[:, self.pcp_rank, - self.dcp_rank] == 0) + batch_chunk_seq_mask = local_context_lens_allranks[:, self.pcp_rank, self.dcp_rank] == 0 batch_chunk_seq_mask = torch.repeat_interleave( - batch_chunk_seq_mask, - repeats=(query_lens * self.pcp_size).to(self.device)) - chunk_seq_mask_filtered_indices = filter_chunked_req_indices( - query_lens, chunked_req_mask).to(self.device) - chunked_context_metadata = \ - AscendMetadataForPrefill.ChunkedContextMetadata( - actual_chunk_seq_lengths=torch.cumsum(query_lens * pcp_size, dim=0), - actual_seq_lengths_kv=actual_seq_lengths_kv, - chunked_req_mask=chunked_req_mask, - starts=local_chunk_starts, - local_context_lens_allranks=local_context_lens_allranks, - cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk, - kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk, - batch_chunk_seq_mask=batch_chunk_seq_mask, - chunk_seq_mask_filtered_indices=chunk_seq_mask_filtered_indices, - local_total_toks=local_total_toks.item() - ) + batch_chunk_seq_mask, repeats=(query_lens * self.pcp_size).to(self.device) + ) + chunk_seq_mask_filtered_indices = filter_chunked_req_indices(query_lens, chunked_req_mask).to( + self.device + ) + chunked_context_metadata = AscendMetadataForPrefill.ChunkedContextMetadata( + actual_chunk_seq_lengths=torch.cumsum(query_lens * pcp_size, dim=0), + actual_seq_lengths_kv=actual_seq_lengths_kv, + chunked_req_mask=chunked_req_mask, + starts=local_chunk_starts, + local_context_lens_allranks=local_context_lens_allranks, + cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk, + kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk, + batch_chunk_seq_mask=batch_chunk_seq_mask, + chunk_seq_mask_filtered_indices=chunk_seq_mask_filtered_indices, + local_total_toks=local_total_toks.item(), + ) attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens tail_attn_nomask_seqlens = common_long_seq_metadata.tail_attn_nomask_seqlens if pcp_size > 1: - attn_mask_seqlens = torch.cumsum(attn_mask_seqlens[0], - dim=0).tolist() - head_attn_nomask_seqlens = torch.cumsum( - head_attn_nomask_seqlens[1], dim=0).tolist() - tail_attn_nomask_seqlens = torch.cumsum( - tail_attn_nomask_seqlens[1], dim=0).tolist() + attn_mask_seqlens = 
torch.cumsum(attn_mask_seqlens[0], dim=0).tolist() + head_attn_nomask_seqlens = torch.cumsum(head_attn_nomask_seqlens[1], dim=0).tolist() + tail_attn_nomask_seqlens = torch.cumsum(tail_attn_nomask_seqlens[1], dim=0).tolist() pcp_metadata = AscendPCPMetadata( q_head_idx=common_long_seq_metadata.q_head_idx_tensor, q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor, - kv_with_q_head_nomask_idx=common_long_seq_metadata. - kv_with_q_head_nomask_idx_tensor, - kv_with_q_head_mask_idx=common_long_seq_metadata. - kv_with_q_head_mask_idx_tensor, - kv_with_q_tail_nomask_idx=common_long_seq_metadata. - kv_with_q_tail_nomask_idx_tensor, - kv_with_q_tail_mask_idx=common_long_seq_metadata. - kv_with_q_tail_mask_idx_tensor, + kv_with_q_head_nomask_idx=common_long_seq_metadata.kv_with_q_head_nomask_idx_tensor, + kv_with_q_head_mask_idx=common_long_seq_metadata.kv_with_q_head_mask_idx_tensor, + kv_with_q_tail_nomask_idx=common_long_seq_metadata.kv_with_q_tail_nomask_idx_tensor, + kv_with_q_tail_mask_idx=common_long_seq_metadata.kv_with_q_tail_mask_idx_tensor, attn_mask_seqlens=attn_mask_seqlens, head_attn_nomask_seqlens=head_attn_nomask_seqlens, tail_attn_nomask_seqlens=tail_attn_nomask_seqlens, q_full_idx=common_long_seq_metadata.q_full_idx, - pcp_allgather_restore_idx=common_long_seq_metadata. - pcp_allgather_restore_idx) + pcp_allgather_restore_idx=common_long_seq_metadata.pcp_allgather_restore_idx, + ) prefill_metadata = AscendMetadataForPrefill( pcp_metadata=pcp_metadata, chunked_context=chunked_context_metadata, block_tables=block_table[num_decodes:], - actual_seq_lengths_q=torch.cumsum(query_lens, dim=0)) + actual_seq_lengths_q=torch.cumsum(query_lens, dim=0), + ) if num_decodes > 0: - num_computed_tokens_array = np.array( - num_computed_tokens_of_pcp_dcp) + num_computed_tokens_array = np.array(num_computed_tokens_of_pcp_dcp) num_computed_tokens_array = num_computed_tokens_array[:num_decodes] - batch_seq_mask = (num_computed_tokens_array[:, self.pcp_rank, - self.dcp_rank] == 0) + batch_seq_mask = num_computed_tokens_array[:, self.pcp_rank, self.dcp_rank] == 0 # TODO: numpy array mode of the shared memory is used to improve performance - self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_( - torch.from_numpy(batch_seq_mask), non_blocking=True) + self.batch_seq_mask_buf[: batch_seq_mask.shape[0]].copy_( + torch.from_numpy(batch_seq_mask), non_blocking=True + ) decode_metadata = AscendMetadataForDecode( num_computed_tokens_of_pcp_dcp=num_computed_tokens_array, - batch_seq_mask=self.batch_seq_mask_buf[:batch_seq_mask. 
- shape[0]], - block_tables=block_table[:num_decodes]) + batch_seq_mask=self.batch_seq_mask_buf[: batch_seq_mask.shape[0]], + block_tables=block_table[:num_decodes], + ) attn_metadata = AscendMetadata( num_actual_tokens=num_actual_tokens, @@ -262,52 +249,60 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): num_prefills=num_prefills, num_decodes=num_decodes, prefill=prefill_metadata, - decode_meta=decode_metadata) + decode_meta=decode_metadata, + ) return attn_metadata class AscendAttentionCPImpl(AscendAttentionBackendImpl): - def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float], + logits_soft_cap: float | None, attn_type: str, - kv_sharing_target_layer_name: Optional[str], + kv_sharing_target_layer_name: str | None, **kwargs, ) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **kwargs) + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **kwargs, + ) self.pcp_size = get_pcp_group().world_size - self.pcp_rank = get_pcp_group( - ).rank_in_group if self.pcp_size > 1 else 0 - self.pcp_group = get_pcp_group( - ).device_group if self.pcp_size > 1 else None + self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0 + self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None self.dcp_size = get_decode_context_model_parallel_world_size() - self.dcp_rank = get_decode_context_model_parallel_rank( - ) if self.dcp_size > 1 else 0 - self.dcp_group = get_dcp_group( - ).device_group if self.dcp_size > 1 else None + self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0 + self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None - def _attention_with_nomask_and_mask(self, q: torch.Tensor, - q_seqlens: List[int], - k_nomask: torch.Tensor, - v_nomask: torch.Tensor, - kv_seqlens_nomask: List[int], - k_mask: torch.Tensor, - v_mask: torch.Tensor, - kv_seqlens_mask: List[int], - mask: torch.Tensor, - attn_metadata) -> torch.Tensor: + def _attention_with_nomask_and_mask( + self, + q: torch.Tensor, + q_seqlens: list[int], + k_nomask: torch.Tensor, + v_nomask: torch.Tensor, + kv_seqlens_nomask: list[int], + k_mask: torch.Tensor, + v_mask: torch.Tensor, + kv_seqlens_mask: list[int], + mask: torch.Tensor, + attn_metadata, + ) -> torch.Tensor: # nomask Attention if k_nomask is not None: attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score( @@ -324,7 +319,8 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): antiquant_scale=None, softmax_lse_flag=True, actual_seq_lengths_kv=kv_seqlens_nomask, - actual_seq_lengths=q_seqlens) + actual_seq_lengths=q_seqlens, + ) # mask Attention attn_out_mask, attn_lse_mask = torch.ops.npu.npu_fused_infer_attention_score( @@ -341,33 +337,29 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): antiquant_scale=None, softmax_lse_flag=True, actual_seq_lengths_kv=kv_seqlens_mask, - actual_seq_lengths=q_seqlens) + actual_seq_lengths=q_seqlens, + ) # update output = attn_out_mask attn_lse = attn_lse_mask if k_nomask is not None: if attn_metadata.prefill is not 
None and attn_metadata.prefill.chunked_context is None: - output = self._npu_attn_out_lse_update(attn_lse_mask, - attn_lse_nomask, - attn_out_mask, - attn_out_nomask) + output = self._npu_attn_out_lse_update(attn_lse_mask, attn_lse_nomask, attn_out_mask, attn_out_nomask) attn_lse = None else: output, attn_lse = self._update_out_and_lse( torch.stack([attn_out_nomask, attn_out_mask], dim=0), - torch.stack([attn_lse_nomask, attn_lse_mask], dim=0)) + torch.stack([attn_lse_nomask, attn_lse_mask], dim=0), + ) return output, attn_lse - def _npu_attn_out_lse_update(self, attn_lse_mask, attn_lse_nomask, - attn_out_mask, attn_out_nomask): + def _npu_attn_out_lse_update(self, attn_lse_mask, attn_lse_nomask, attn_out_mask, attn_out_nomask): T = attn_out_mask.shape[0] N = attn_out_mask.shape[1] D = attn_out_mask.shape[2] - attn_out_mask, attn_lse_mask = self._out_lse_reshape( - attn_out_mask, attn_lse_mask) - attn_out_nomask, attn_lse_nomask = self._out_lse_reshape( - attn_out_nomask, attn_lse_nomask) + attn_out_mask, attn_lse_mask = self._out_lse_reshape(attn_out_mask, attn_lse_mask) + attn_out_nomask, attn_lse_nomask = self._out_lse_reshape(attn_out_nomask, attn_lse_nomask) attn_out_mask = attn_out_mask.to(torch.float32) attn_out_nomask = attn_out_nomask.to(torch.float32) attn_lse_mask = attn_lse_mask.to(torch.float32) @@ -375,22 +367,17 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): attn_output = [attn_out_nomask, attn_out_mask] attn_lse = [attn_lse_nomask, attn_lse_mask] update_type = 0 - output, _ = torch_npu.npu_attention_update(attn_lse, attn_output, - update_type) + output, _ = torch_npu.npu_attention_update(attn_lse, attn_output, update_type) output = output.view(T, N, D) return output - def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor, - value: torch.Tensor, - attn_metadata: AscendMetadata) -> torch.Tensor: + def _forward_prefill_cp( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_metadata: AscendMetadata + ) -> torch.Tensor: + data_head, data_tail = self._forward_prefill_cp_pre(query, key, value, attn_metadata) - data_head, data_tail = self._forward_prefill_cp_pre( - query, key, value, attn_metadata) - - output_head, lse_head = self._forward_prefill_cp_attn( - data_head, True, attn_metadata) - output_tail, lse_tail = self._forward_prefill_cp_attn( - data_tail, False, attn_metadata) + output_head, lse_head = self._forward_prefill_cp_attn(data_head, True, attn_metadata) + output_tail, lse_tail = self._forward_prefill_cp_attn(data_tail, False, attn_metadata) output, attn_lse = self._forward_prefill_cp_post( [output_head, output_tail], @@ -399,9 +386,9 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): ) return output, attn_lse - def _forward_prefill_cp_pre(self, query: torch.Tensor, key: torch.Tensor, - value: torch.Tensor, - attn_metadata: AscendMetadata) -> torch.Tensor: + def _forward_prefill_cp_pre( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_metadata: AscendMetadata + ) -> torch.Tensor: assert attn_metadata is not None assert attn_metadata.prefill is not None assert attn_metadata.prefill.pcp_metadata is not None @@ -414,41 +401,46 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): kv_with_q_tail_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx q_head = torch.index_select(query, 0, q_head_idx) q_tail = torch.index_select(query, 0, q_tail_idx) - k_head_nomask=torch.index_select(key, 0, kv_with_q_head_nomask_idx) \ - if self.pcp_rank > 0 else None - 
v_head_nomask=torch.index_select(value, 0, kv_with_q_head_nomask_idx) \ - if self.pcp_rank > 0 else None + k_head_nomask = torch.index_select(key, 0, kv_with_q_head_nomask_idx) if self.pcp_rank > 0 else None + v_head_nomask = torch.index_select(value, 0, kv_with_q_head_nomask_idx) if self.pcp_rank > 0 else None k_head_mask = torch.index_select(key, 0, kv_with_q_head_mask_idx) v_head_mask = torch.index_select(value, 0, kv_with_q_head_mask_idx) k_tail_nomask = torch.index_select(key, 0, kv_with_q_tail_nomask_idx) v_tail_nomask = torch.index_select(value, 0, kv_with_q_tail_nomask_idx) k_tail_mask = torch.index_select(key, 0, kv_with_q_tail_mask_idx) v_tail_mask = torch.index_select(value, 0, kv_with_q_tail_mask_idx) - return { - "q": q_head, - "k_nomask": k_head_nomask, - "v_nomask": v_head_nomask, - "k_mask": k_head_mask, - "v_mask": v_head_mask, - }, { - "q": q_tail, - "k_nomask": k_tail_nomask, - "v_nomask": v_tail_nomask, - "k_mask": k_tail_mask, - "v_mask": v_tail_mask, - }, + return ( + { + "q": q_head, + "k_nomask": k_head_nomask, + "v_nomask": v_head_nomask, + "k_mask": k_head_mask, + "v_mask": v_head_mask, + }, + { + "q": q_tail, + "k_nomask": k_tail_nomask, + "v_nomask": v_tail_nomask, + "k_mask": k_tail_mask, + "v_mask": v_tail_mask, + }, + ) def _forward_prefill_cp_attn(self, data, is_head, attn_metadata): attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens - nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens \ - if is_head else attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens + nomask_seqlens = ( + attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens + if is_head + else attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens + ) output, lse = self._attention_with_nomask_and_mask( **data, q_seqlens=attn_mask_seqlens, kv_seqlens_nomask=nomask_seqlens, kv_seqlens_mask=attn_mask_seqlens, mask=attn_metadata.attn_mask, - attn_metadata=attn_metadata) + attn_metadata=attn_metadata, + ) return output, lse def _forward_prefill_cp_post(self, outputs, lses, attn_metadata): @@ -456,20 +448,15 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): output = torch.index_select(torch.cat(outputs, dim=0), 0, q_full_idx) attn_lse = None if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is not None: - attn_lse = torch.index_select(torch.cat(lses, dim=0), 0, - q_full_idx) + attn_lse = torch.index_select(torch.cat(lses, dim=0), 0, q_full_idx) return output, attn_lse - def _out_lse_reshape(self, attn_out: torch.Tensor, - attn_lse: torch.Tensor) -> torch.Tensor: - attn_out = attn_out.contiguous().view( - attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2]) - attn_lse = attn_lse.contiguous().view( - attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2]) + def _out_lse_reshape(self, attn_out: torch.Tensor, attn_lse: torch.Tensor) -> torch.Tensor: + attn_out = attn_out.contiguous().view(attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2]) + attn_lse = attn_lse.contiguous().view(attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2]) return attn_out, attn_lse - def _forward_decode_pcp_dcp(self, query: torch.Tensor, - attn_metadata: AscendMetadata) -> torch.Tensor: + def _forward_decode_pcp_dcp(self, query: torch.Tensor, attn_metadata: AscendMetadata) -> torch.Tensor: assert self.key_cache is not None assert self.value_cache is not None @@ -479,36 +466,23 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): else: num_heads = self.num_heads - k_nope = 
self.key_cache.view(self.key_cache.shape[0], - self.key_cache.shape[1], -1) - value = self.value_cache.view(self.key_cache.shape[0], - self.key_cache.shape[1], -1) + k_nope = self.key_cache.view(self.key_cache.shape[0], self.key_cache.shape[1], -1) + value = self.value_cache.view(self.key_cache.shape[0], self.key_cache.shape[1], -1) common_kwargs = { - 'num_heads': - num_heads, - 'num_key_value_heads': - self.num_kv_heads, - 'input_layout': - 'TND', - 'atten_mask': - None, - 'scale': - self.scale, - 'antiquant_mode': - 0, - 'antiquant_scale': - None, - 'softmax_lse_flag': - True, - 'block_table': - attn_metadata.decode_meta.block_tables, - 'block_size': - self.key_cache.shape[1], - 'actual_seq_lengths_kv': - attn_metadata.decode_meta. - num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank], - 'actual_seq_lengths': - attn_metadata.actual_seq_lengths_q[:attn_metadata.num_decodes], + "num_heads": num_heads, + "num_key_value_heads": self.num_kv_heads, + "input_layout": "TND", + "atten_mask": None, + "scale": self.scale, + "antiquant_mode": 0, + "antiquant_scale": None, + "softmax_lse_flag": True, + "block_table": attn_metadata.decode_meta.block_tables, + "block_size": self.key_cache.shape[1], + "actual_seq_lengths_kv": attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[ + :, self.pcp_rank, self.dcp_rank + ], + "actual_seq_lengths": attn_metadata.actual_seq_lengths_q[: attn_metadata.num_decodes], } graph_params = get_graph_params() forward_context: ForwardContext = get_forward_context() @@ -524,44 +498,44 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): workspace = graph_params.workspaces.get(num_tokens) if workspace is None: workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( - query, k_nope, value, **common_kwargs) - update_graph_params_workspaces(num_tokens, - weak_ref_tensors(workspace)) + query, k_nope, value, **common_kwargs + ) + update_graph_params_workspaces(num_tokens, weak_ref_tensors(workspace)) attn_out = torch.empty_like(query) - attn_lse = torch.empty((num_tokens, num_heads, 1), - dtype=torch.float, - device=query.device) + attn_lse = torch.empty((num_tokens, num_heads, 1), dtype=torch.float, device=query.device) - graph_params.attn_params[num_tokens].append(( - weak_ref_tensors(query), weak_ref_tensors(k_nope), - weak_ref_tensors(value), self.num_heads, self.num_kv_heads, - self.scale, attn_metadata.block_tables, - self.key_cache.shape[1], attn_metadata.decode_meta. 
- num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, - self.dcp_rank], - attn_metadata.actual_seq_lengths_q[:attn_metadata.num_decodes], - weak_ref_tensors(attn_out), weak_ref_tensors(attn_lse), - self.dcp_size, self.pcp_rank, self.dcp_rank)) + graph_params.attn_params[num_tokens].append( + ( + weak_ref_tensors(query), + weak_ref_tensors(k_nope), + weak_ref_tensors(value), + self.num_heads, + self.num_kv_heads, + self.scale, + attn_metadata.block_tables, + self.key_cache.shape[1], + attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank], + attn_metadata.actual_seq_lengths_q[: attn_metadata.num_decodes], + weak_ref_tensors(attn_out), + weak_ref_tensors(attn_lse), + self.dcp_size, + self.pcp_rank, + self.dcp_rank, + ) + ) torch.npu.graph_task_group_begin(stream) torch_npu.npu_fused_infer_attention_score.out( - query, - k_nope, - value, - **common_kwargs, - workspace=workspace, - out=[attn_out, attn_lse]) + query, k_nope, value, **common_kwargs, workspace=workspace, out=[attn_out, attn_lse] + ) handle = torch.npu.graph_task_group_end(stream) graph_params.handles[num_tokens].append(handle) else: - attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score( - query, k_nope, value, **common_kwargs) - attn_out_lse = _process_attn_out_lse( - attn_out, attn_lse, attn_metadata.decode_meta.batch_seq_mask) + attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(query, k_nope, value, **common_kwargs) + attn_out_lse = _process_attn_out_lse(attn_out, attn_lse, attn_metadata.decode_meta.batch_seq_mask) attn_out = _npu_attention_update(self.head_size, attn_out_lse) return attn_out - def _update_out_and_lse(self, out_list: torch.Tensor, - lse_list: torch.Tensor) -> torch.Tensor: + def _update_out_and_lse(self, out_list: torch.Tensor, lse_list: torch.Tensor) -> torch.Tensor: """LSE_final = log(sum(exp(LSE_i))), O_final = sum(exp(LSE_i - LSE_final) * O_i) Args: out_list: shape = [N, batch_size, num_heads, head_size] @@ -571,57 +545,58 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): lse_final: shape = [batch_size, num_heads, 1] """ lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False) - out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list, - dim=0) + out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list, dim=0) return out_final, lse_final def _update_chunk_attn_out_lse_with_current_attn_out_lse( - self, current_attn_output_prefill, current_attn_lse_prefill, - attn_output_full_chunk, attn_lse_full_chunk, prefill_query, - attn_metadata): + self, + current_attn_output_prefill, + current_attn_lse_prefill, + attn_output_full_chunk, + attn_lse_full_chunk, + prefill_query, + attn_metadata, + ): if self.pcp_size > 1: inverse_idx = attn_metadata.prefill.chunked_context.kv_inverse_idx_for_chunk - attn_output_full_chunk = torch.index_select( - attn_output_full_chunk, 0, inverse_idx) - attn_lse_full_chunk = torch.index_select(attn_lse_full_chunk, 0, - inverse_idx) + attn_output_full_chunk = torch.index_select(attn_output_full_chunk, 0, inverse_idx) + attn_lse_full_chunk = torch.index_select(attn_lse_full_chunk, 0, inverse_idx) num_tokens = prefill_query.size(0) attn_output_full_chunk = attn_output_full_chunk[ - self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :] - attn_lse_full_chunk = attn_lse_full_chunk[ - self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :] + self.pcp_rank * num_tokens : (self.pcp_rank + 1) * num_tokens, :, : + ] + attn_lse_full_chunk = attn_lse_full_chunk[self.pcp_rank * num_tokens : 
(self.pcp_rank + 1) * num_tokens, :, :] - assert attn_output_full_chunk.shape == current_attn_output_prefill.shape and attn_lse_full_chunk.shape == current_attn_lse_prefill.shape + assert ( + attn_output_full_chunk.shape == current_attn_output_prefill.shape + and attn_lse_full_chunk.shape == current_attn_lse_prefill.shape + ) filtered_indices = attn_metadata.prefill.chunked_context.chunk_seq_mask_filtered_indices - attn_output_prefill_filtered = current_attn_output_prefill[ - filtered_indices, :, :] - attn_lse_prefill_filtered = current_attn_lse_prefill[ - filtered_indices, :, :] + attn_output_prefill_filtered = current_attn_output_prefill[filtered_indices, :, :] + attn_lse_prefill_filtered = current_attn_lse_prefill[filtered_indices, :, :] attn_output_full_chunk = attn_output_full_chunk[filtered_indices, :, :] attn_lse_full_chunk = attn_lse_full_chunk[filtered_indices, :, :] attn_output_filtered = self._npu_attn_out_lse_update( - attn_lse_prefill_filtered, attn_lse_full_chunk, - attn_output_prefill_filtered, attn_output_full_chunk) + attn_lse_prefill_filtered, attn_lse_full_chunk, attn_output_prefill_filtered, attn_output_full_chunk + ) - current_attn_output_prefill[ - filtered_indices, :, :] = attn_output_filtered.to( - current_attn_output_prefill.dtype) + current_attn_output_prefill[filtered_indices, :, :] = attn_output_filtered.to(current_attn_output_prefill.dtype) def _prefill_query_all_gather(self, attn_metadata, prefill_query): if self.pcp_size > 1: prefill_query = get_pcp_group().all_gather(prefill_query, 0) prefill_query = torch.index_select( - prefill_query, 0, attn_metadata.prefill.chunked_context. - cp_kv_recover_idx_for_chunk) + prefill_query, 0, attn_metadata.prefill.chunked_context.cp_kv_recover_idx_for_chunk + ) if self.dcp_size > 1: prefill_query = get_dcp_group().all_gather(prefill_query, 1) return prefill_query - def _compute_prefill_context(self, query: torch.Tensor, - kv_cache: Tuple[torch.Tensor], - attn_metadata: AscendMetadata): + def _compute_prefill_context( + self, query: torch.Tensor, kv_cache: tuple[torch.Tensor], attn_metadata: AscendMetadata + ): assert len(kv_cache) > 1 assert attn_metadata is not None assert attn_metadata.prefill is not None @@ -630,26 +605,23 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): local_chunked_kv_lens = prefill_metadata.chunked_context.local_context_lens_allranks assert local_chunked_kv_lens is not None - local_chunked_kv_lens_rank = local_chunked_kv_lens[:, self.pcp_rank, - self.dcp_rank] + local_chunked_kv_lens_rank = local_chunked_kv_lens[:, self.pcp_rank, self.dcp_rank] total_toks = prefill_metadata.chunked_context.local_total_toks - key, value = self._load_kv_for_chunk(attn_metadata, kv_cache, - local_chunked_kv_lens_rank, query, - total_toks) + key, value = self._load_kv_for_chunk(attn_metadata, kv_cache, local_chunked_kv_lens_rank, query, total_toks) if self.dcp_size > 1: num_heads = self.num_heads * self.dcp_size else: num_heads = self.num_heads if total_toks == 0: - return (torch.full((query.size(0), num_heads, self.head_size), - fill_value=0, - dtype=query.dtype, - device=query.device), - torch.full((query.size(0), num_heads, 1), - fill_value=-torch.inf, - dtype=torch.float32, - device=query.device)) + return ( + torch.full( + (query.size(0), num_heads, self.head_size), fill_value=0, dtype=query.dtype, device=query.device + ), + torch.full( + (query.size(0), num_heads, 1), fill_value=-torch.inf, dtype=torch.float32, device=query.device + ), + ) prefix_chunk_output, prefix_chunk_lse = 
torch.ops.npu.npu_fused_infer_attention_score( query, @@ -664,42 +636,31 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): antiquant_mode=0, antiquant_scale=None, softmax_lse_flag=True, - actual_seq_lengths_kv=prefill_metadata.chunked_context. - actual_seq_lengths_kv, - actual_seq_lengths=attn_metadata.prefill.chunked_context. - actual_chunk_seq_lengths) + actual_seq_lengths_kv=prefill_metadata.chunked_context.actual_seq_lengths_kv, + actual_seq_lengths=attn_metadata.prefill.chunked_context.actual_chunk_seq_lengths, + ) batch_chunk_seq_mask = attn_metadata.prefill.chunked_context.batch_chunk_seq_mask - lse_mask = batch_chunk_seq_mask[:, None, - None].expand_as(prefix_chunk_lse) + lse_mask = batch_chunk_seq_mask[:, None, None].expand_as(prefix_chunk_lse) prefix_chunk_lse = torch.where(lse_mask, -torch.inf, prefix_chunk_lse) return prefix_chunk_output, prefix_chunk_lse - def _load_kv_for_chunk(self, attn_metadata, kv_cache, - local_chunked_kv_lens_rank, query, total_toks): + def _load_kv_for_chunk(self, attn_metadata, kv_cache, local_chunked_kv_lens_rank, query, total_toks): cache_key = kv_cache[0] cache_value = kv_cache[1] num_heads = cache_key.size(2) head_size = kv_cache[0].size(-1) - key = torch.empty(total_toks, - num_heads, - head_size, - dtype=query.dtype, - device=query.device) - value = torch.empty(total_toks, - num_heads, - head_size, - dtype=query.dtype, - device=query.device) + key = torch.empty(total_toks, num_heads, head_size, dtype=query.dtype, device=query.device) + value = torch.empty(total_toks, num_heads, head_size, dtype=query.dtype, device=query.device) if total_toks > 0: torch_npu.atb.npu_paged_cache_load( cache_key, cache_value, attn_metadata.prefill.block_tables, local_chunked_kv_lens_rank, - seq_starts=attn_metadata.prefill.chunked_context. - starts, # slot offsets of current chunk in current iteration + # slot offsets of current chunk in current iteration + seq_starts=attn_metadata.prefill.chunked_context.starts, key=key, value=value, ) @@ -709,10 +670,9 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): self, key: torch.Tensor, value: torch.Tensor, - kv_cache: Tuple[torch.Tensor], + kv_cache: tuple[torch.Tensor], attn_metadata: AscendMetadata, ): - num_decode_tokens = attn_metadata.num_decode_tokens has_decode = attn_metadata.num_decodes > 0 has_prefill = attn_metadata.num_prefills > 0 @@ -722,60 +682,50 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] if has_decode: - slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * - self.pcp_size:self. 
- pcp_size] + slot_mapping = attn_metadata.slot_mapping[: num_decode_tokens * self.pcp_size : self.pcp_size] torch_npu._npu_reshape_and_cache( key=key[:num_decode_tokens], value=value[:num_decode_tokens], key_cache=self.key_cache, value_cache=self.value_cache, - slot_indices=slot_mapping) + slot_indices=slot_mapping, + ) if has_prefill: if self.pcp_size > 1: kv = torch.cat([key, value], dim=-1) num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size - all_kv = get_pcp_group().all_gather( - kv[:num_actual_tokens_pcp_padded].contiguous(), dim=0) + all_kv = get_pcp_group().all_gather(kv[:num_actual_tokens_pcp_padded].contiguous(), dim=0) assert attn_metadata.prefill is not None assert attn_metadata.prefill.pcp_metadata is not None pcp_allgather_restore_idx = attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx - all_kv = torch.index_select(all_kv, 0, - pcp_allgather_restore_idx) - key, value = all_kv.split([self.head_size, self.head_size], - dim=-1) - prefill_key = key[self.pcp_size * - num_decode_tokens:attn_metadata. - num_actual_tokens_pcp_padded] - prefill_value = value[self.pcp_size * - num_decode_tokens:attn_metadata. - num_actual_tokens_pcp_padded] + all_kv = torch.index_select(all_kv, 0, pcp_allgather_restore_idx) + key, value = all_kv.split([self.head_size, self.head_size], dim=-1) + prefill_key = key[self.pcp_size * num_decode_tokens : attn_metadata.num_actual_tokens_pcp_padded] + prefill_value = value[self.pcp_size * num_decode_tokens : attn_metadata.num_actual_tokens_pcp_padded] slot_mapping = attn_metadata.slot_mapping[ - self.pcp_size * num_decode_tokens:attn_metadata. - num_actual_tokens_pcp_padded] - torch_npu._npu_reshape_and_cache(key=prefill_key, - value=prefill_value, - key_cache=self.key_cache, - value_cache=self.value_cache, - slot_indices=slot_mapping) + self.pcp_size * num_decode_tokens : attn_metadata.num_actual_tokens_pcp_padded + ] + torch_npu._npu_reshape_and_cache( + key=prefill_key, + value=prefill_value, + key_cache=self.key_cache, + value_cache=self.value_cache, + slot_indices=slot_mapping, + ) return key, value def _gather_global_context_output(self, local_context_attn_output): if self.dcp_size > 1: - dcp_context_attn_output = torch.empty_like( - local_context_attn_output) - dist.all_to_all_single(dcp_context_attn_output, - local_context_attn_output, - group=self.dcp_group) + dcp_context_attn_output = torch.empty_like(local_context_attn_output) + dist.all_to_all_single(dcp_context_attn_output, local_context_attn_output, group=self.dcp_group) else: dcp_context_attn_output = local_context_attn_output if self.pcp_size > 1: # AllGather out&lse within CP group - global_context_attn_output = get_pcp_group().all_gather( - dcp_context_attn_output, dim=-1) + global_context_attn_output = get_pcp_group().all_gather(dcp_context_attn_output, dim=-1) else: global_context_attn_output = dcp_context_attn_output @@ -788,17 +738,14 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): D = self.head_size assert D_plus_1 == D + 1 # [PCP, S, DCP, H, D+1] - x = global_context_output.view(self.pcp_size, S, self.dcp_size, H, - D_plus_1) + x = global_context_output.view(self.pcp_size, S, self.dcp_size, H, D_plus_1) # [PCP, DCP, S, H, D+1] x = x.permute(0, 2, 1, 3, 4).contiguous() # Flatten [N, S, H, D+1], N = pcp_size * dcp_size x = x.view(-1, S, H, D_plus_1) # Split out lse - attn_out_allgather, attn_lse_allgather = torch.split( - x, [D, 1], dim=-1) # [N, S, H, D], [N, S, H, 1] - context_output, context_lse = self._update_out_and_lse( - 
attn_out_allgather, attn_lse_allgather) + attn_out_allgather, attn_lse_allgather = torch.split(x, [D, 1], dim=-1) # [N, S, H, D], [N, S, H, 1] + context_output, context_lse = self._update_out_and_lse(attn_out_allgather, attn_lse_allgather) return context_output, context_lse def forward_impl( @@ -806,7 +753,7 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: Tuple[torch.Tensor], + kv_cache: tuple[torch.Tensor], attn_metadata: AscendMetadata, output: torch.Tensor, ) -> torch.Tensor: @@ -816,8 +763,7 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): num_decode_tokens = attn_metadata.num_decode_tokens if has_decode: decode_query = query[:num_decode_tokens] - output_decode = self._forward_decode_pcp_dcp( - decode_query, attn_metadata) + output_decode = self._forward_decode_pcp_dcp(decode_query, attn_metadata) output[:num_decode_tokens] = output_decode if has_prefill: assert attn_metadata.prefill is not None @@ -833,26 +779,21 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): # qkv init num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size - prefill_query = query[ - num_decode_tokens:num_actual_tokens_pcp_padded].contiguous() - key = key[self.pcp_size * num_decode_tokens:].contiguous() - value = value[self.pcp_size * num_decode_tokens:].contiguous() + prefill_query = query[num_decode_tokens:num_actual_tokens_pcp_padded].contiguous() + key = key[self.pcp_size * num_decode_tokens :].contiguous() + value = value[self.pcp_size * num_decode_tokens :].contiguous() if has_chunked_context: # all_gather q for chunked prefill // overlap the computation inner current chunk - cp_chunkedprefill_comm_stream().wait_stream( - torch.npu.current_stream()) + cp_chunkedprefill_comm_stream().wait_stream(torch.npu.current_stream()) with torch_npu.npu.stream(cp_chunkedprefill_comm_stream()): - prefill_query_all = self._prefill_query_all_gather( - attn_metadata, prefill_query.clone()) + prefill_query_all = self._prefill_query_all_gather(attn_metadata, prefill_query.clone()) if self.pcp_size > 1: # Scenario of Enabling PCP or PCP&DCP # prepare qkv and compute the head part // overlap the communication of all gather q - data_head, data_tail = self._forward_prefill_cp_pre( - prefill_query, key, value, attn_metadata) - output_head, lse_head = self._forward_prefill_cp_attn( - data_head, True, attn_metadata) + data_head, data_tail = self._forward_prefill_cp_pre(prefill_query, key, value, attn_metadata) + output_head, lse_head = self._forward_prefill_cp_attn(data_head, True, attn_metadata) else: # Scenario of Enabling DCP Individually attn_output_prefill, attn_lse_prefill = torch.ops.npu.npu_fused_infer_attention_score( @@ -868,32 +809,25 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): antiquant_mode=0, antiquant_scale=None, softmax_lse_flag=True, - actual_seq_lengths_kv=attn_metadata.prefill. - actual_seq_lengths_q, - actual_seq_lengths=attn_metadata.prefill. 
- actual_seq_lengths_q) + actual_seq_lengths_kv=attn_metadata.prefill.actual_seq_lengths_q, + actual_seq_lengths=attn_metadata.prefill.actual_seq_lengths_q, + ) if has_chunked_context: - torch.npu.current_stream().wait_stream( - cp_chunkedprefill_comm_stream()) + torch.npu.current_stream().wait_stream(cp_chunkedprefill_comm_stream()) # computation of context - context_output = self._compute_prefill_context( - prefill_query_all, kv_cache, attn_metadata) + context_output = self._compute_prefill_context(prefill_query_all, kv_cache, attn_metadata) # Note(qcs): (output, lse) -> [Seq, Head_num, Head_dim+1] -> [Head_num, Head_dim+1, Seq] - local_context_output = torch.cat( - context_output, dim=-1).permute([1, 2, 0]).contiguous() + local_context_output = torch.cat(context_output, dim=-1).permute([1, 2, 0]).contiguous() # all2all and all_gather output&lse // overlap the computation inner current chunk - cp_chunkedprefill_comm_stream().wait_stream( - torch.npu.current_stream()) + cp_chunkedprefill_comm_stream().wait_stream(torch.npu.current_stream()) with torch_npu.npu.stream(cp_chunkedprefill_comm_stream()): - global_context_output = self._gather_global_context_output( - local_context_output) + global_context_output = self._gather_global_context_output(local_context_output) if self.pcp_size > 1: # compute the tail part and reorg output&lse // overlap the communication of output - output_tail, lse_tail = self._forward_prefill_cp_attn( - data_tail, False, attn_metadata) + output_tail, lse_tail = self._forward_prefill_cp_attn(data_tail, False, attn_metadata) attn_output_prefill, attn_lse_prefill = self._forward_prefill_cp_post( [output_head, output_tail], @@ -903,16 +837,12 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is not None: # update the output of current chunk with context part - torch.npu.current_stream().wait_stream( - cp_chunkedprefill_comm_stream()) - global_context_output = global_context_output.permute( - [2, 0, 1]).contiguous() - context_output, context_lse = self._update_global_context_output( - global_context_output) + torch.npu.current_stream().wait_stream(cp_chunkedprefill_comm_stream()) + global_context_output = global_context_output.permute([2, 0, 1]).contiguous() + context_output, context_lse = self._update_global_context_output(global_context_output) self._update_chunk_attn_out_lse_with_current_attn_out_lse( - attn_output_prefill, attn_lse_prefill, context_output, - context_lse, prefill_query, attn_metadata) + attn_output_prefill, attn_lse_prefill, context_output, context_lse, prefill_query, attn_metadata + ) - output[num_decode_tokens:attn_output_prefill.shape[0] + - num_decode_tokens] = attn_output_prefill + output[num_decode_tokens : attn_output_prefill.shape[0] + num_decode_tokens] = attn_output_prefill return output diff --git a/vllm_ascend/attention/context_parallel/common_cp.py b/vllm_ascend/attention/context_parallel/common_cp.py index 89debf11..65103059 100644 --- a/vllm_ascend/attention/context_parallel/common_cp.py +++ b/vllm_ascend/attention/context_parallel/common_cp.py @@ -1,12 +1,9 @@ from dataclasses import dataclass -from typing import Optional import torch import torch.distributed as dist import torch_npu -from vllm.distributed import (get_dcp_group, - get_decode_context_model_parallel_world_size, - get_pcp_group) +from vllm.distributed import get_dcp_group, get_decode_context_model_parallel_world_size, get_pcp_group @dataclass @@ -17,6 +14,7 @@ class 
AscendPCPMetadata: Stores index tensors and sequence lengths for routing attention computations across PCP ranks during long sequence processing. """ + q_head_idx: torch.Tensor = None q_tail_idx: torch.Tensor = None kv_with_q_head_nomask_idx: torch.Tensor = None @@ -27,7 +25,7 @@ class AscendPCPMetadata: head_attn_nomask_seqlens: torch.Tensor = None tail_attn_nomask_seqlens: torch.Tensor = None q_full_idx: torch.Tensor = None - pcp_allgather_restore_idx: Optional[list[int]] = None + pcp_allgather_restore_idx: list[int] | None = None @dataclass @@ -37,6 +35,7 @@ class CPChunkedContextMetadata: Extends chunked prefill with per-rank chunk information for PCP/DCP. """ + # For handling chunked prefill cu_seq_lens: torch.Tensor starts: torch.Tensor @@ -47,48 +46,51 @@ class CPChunkedContextMetadata: chunk_seq_lens_npu: torch.Tensor # for mla DCP & PCP padded_chunk_seq_lens_npu: torch.Tensor = None - padded_local_chunk_seq_lens: Optional[list[list[int]]] = None - local_context_lens_allranks: Optional[list[list[int]]] = None + padded_local_chunk_seq_lens: list[list[int]] | None = None + local_context_lens_allranks: list[list[int]] | None = None padded_local_cu_seq_lens: torch.Tensor = None - cu_seq_lens_lst: Optional[list[list[int]]] = None - chunk_size: Optional[int] = None + cu_seq_lens_lst: list[list[int]] | None = None + chunk_size: int | None = None @dataclass class AscendMetadataForPrefill: - """ Prefill-specific metadata for Ascend attention with Context Parallelism.""" + """Prefill-specific metadata for Ascend attention with Context Parallelism.""" @dataclass class ChunkedContextMetadata: """Metadata for chunked context processing within prefill phase.""" + actual_chunk_seq_lengths: torch.Tensor actual_seq_lengths_kv: torch.Tensor starts: torch.Tensor chunk_seq_mask_filtered_indices: torch.Tensor - chunked_req_mask: Optional[list[bool]] = None - local_context_lens_allranks: Optional[list[list[int]]] = None - cp_kv_recover_idx_for_chunk: Optional[list[int]] = None - kv_inverse_idx_for_chunk: Optional[list[int]] = None - batch_chunk_seq_mask: Optional[list[bool]] = None - local_total_toks: Optional[int] = None + chunked_req_mask: list[bool] | None = None + local_context_lens_allranks: list[list[int]] | None = None + cp_kv_recover_idx_for_chunk: list[int] | None = None + kv_inverse_idx_for_chunk: list[int] | None = None + batch_chunk_seq_mask: list[bool] | None = None + local_total_toks: int | None = None """ Prefill Specific Metadata for Ascend""" - pcp_metadata: Optional[AscendPCPMetadata] = None - chunked_context: Optional[ChunkedContextMetadata] = None + pcp_metadata: AscendPCPMetadata | None = None + chunked_context: ChunkedContextMetadata | None = None block_tables: torch.Tensor = None actual_seq_lengths_q: torch.Tensor = None @dataclass class AscendMetadataForDecode: - """ Decode-specific metadata for Ascend attention with Context Parallelism.""" - num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None + """Decode-specific metadata for Ascend attention with Context Parallelism.""" + + num_computed_tokens_of_pcp_dcp: list[list[list[int]]] | None = None batch_seq_mask: torch.Tensor = None block_tables: torch.Tensor = None -def _process_attn_out_lse(attn_output: torch.Tensor, softmax_lse: torch.Tensor, - batch_seq_mask: torch.Tensor) -> torch.Tensor: +def _process_attn_out_lse( + attn_output: torch.Tensor, softmax_lse: torch.Tensor, batch_seq_mask: torch.Tensor +) -> torch.Tensor: pcp_size = get_pcp_group().world_size dcp_size = 
get_decode_context_model_parallel_world_size() dcp_group = get_dcp_group().device_group if dcp_size > 1 else None @@ -104,21 +106,17 @@ def _process_attn_out_lse(attn_output: torch.Tensor, softmax_lse: torch.Tensor, # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs] attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous() attn_out_lse_all2all = torch.empty_like(attn_out_lse) - dist.all_to_all_single(attn_out_lse_all2all, - attn_out_lse, - group=dcp_group) + dist.all_to_all_single(attn_out_lse_all2all, attn_out_lse, group=dcp_group) attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1]) if pcp_size > 1: # AllGather out&lse within CP group - attn_out_lse = get_pcp_group().all_gather(attn_out_lse.contiguous(), - dim=0) + attn_out_lse = get_pcp_group().all_gather(attn_out_lse.contiguous(), dim=0) return attn_out_lse -def _npu_attention_update(head_size, - attn_out_lse: torch.Tensor) -> torch.Tensor: +def _npu_attention_update(head_size, attn_out_lse: torch.Tensor) -> torch.Tensor: pcp_size = get_pcp_group().world_size dcp_size = get_decode_context_model_parallel_world_size() # [PCP * S, DCP * H, D+1] @@ -134,8 +132,7 @@ def _npu_attention_update(head_size, # Flatten [N, S, H, D+1], N = pcp_size * dcp_size x = x.view(-1, S, H, D_plus_1) # Split out lse - out_flat, lse_flat = torch.split(x, [D, 1], - dim=-1) # [N, S, H, D], [N, S, H, 1] + out_flat, lse_flat = torch.split(x, [D, 1], dim=-1) # [N, S, H, D], [N, S, H, 1] # out: [N, S, H, D] -> [N, S*H, D] # lse: [N, S, H, 1] -> [N, S*H] out_flat = out_flat.flatten(1, 2) # [N, S*H, D] diff --git a/vllm_ascend/attention/context_parallel/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py index 7740092f..7249814c 100644 --- a/vllm_ascend/attention/context_parallel/mla_cp.py +++ b/vllm_ascend/attention/context_parallel/mla_cp.py @@ -1,35 +1,43 @@ -from typing import Optional, Tuple, TypeVar +from typing import TypeVar import numpy as np import torch import torch_npu from vllm.config import VllmConfig -from vllm.distributed import (get_dcp_group, - get_decode_context_model_parallel_rank, - get_decode_context_model_parallel_world_size, - get_pcp_group) +from vllm.distributed import ( + get_dcp_group, + get_decode_context_model_parallel_rank, + get_decode_context_model_parallel_world_size, + get_pcp_group, +) from vllm.forward_context import ForwardContext, get_forward_context from vllm.utils.math_utils import cdiv from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec # isort: off from vllm_ascend.attention.mla_v1 import ( - AscendMLADecodeMetadata, AscendMLAImpl, AscendMLAMetadata, - AscendMLAMetadataBuilder, AscendMLAPrefillMetadata, - DecodeMLAPreprocessResult, PrefillMLAPreprocessResult, - BUILD_METADATA_STEP_PREFILL) -#isort: on + AscendMLADecodeMetadata, + AscendMLAImpl, + AscendMLAMetadata, + AscendMLAMetadataBuilder, + AscendMLAPrefillMetadata, + DecodeMLAPreprocessResult, + PrefillMLAPreprocessResult, + BUILD_METADATA_STEP_PREFILL, +) +# isort: on -from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata) from vllm_ascend.attention.context_parallel.common_cp import ( - AscendPCPMetadata, CPChunkedContextMetadata, _process_attn_out_lse, - _npu_attention_update) -from vllm_ascend.compilation.acl_graph import (get_draft_graph_params, - get_graph_params, - update_graph_params_workspaces) -from vllm_ascend.utils import weak_ref_tensors, vllm_version_is + AscendPCPMetadata, + CPChunkedContextMetadata, + _npu_attention_update, + _process_attn_out_lse, +) +from vllm_ascend.attention.utils 
import AscendCommonAttentionMetadata +from vllm_ascend.compilation.acl_graph import get_draft_graph_params, get_graph_params, update_graph_params_workspaces +from vllm_ascend.utils import vllm_version_is, weak_ref_tensors -if vllm_version_is('0.13.0'): +if vllm_version_is("0.13.0"): from vllm.v1.attention.backends.utils import AttentionCGSupport else: from vllm.v1.attention.backend import AttentionCGSupport @@ -54,28 +62,21 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder): metadata_cls: type[AscendMLAMetadata] | None = None, supports_dcp_with_varlen: bool = False, ): - super().__init__(kv_cache_spec, layer_names, vllm_config, device, - metadata_cls, supports_dcp_with_varlen) + super().__init__(kv_cache_spec, layer_names, vllm_config, device, metadata_cls, supports_dcp_with_varlen) self.pcp_size = get_pcp_group().world_size - self.pcp_rank = get_pcp_group( - ).rank_in_group if self.pcp_size > 1 else 0 + self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0 self.dcp_size = get_decode_context_model_parallel_world_size() - self.dcp_rank = get_decode_context_model_parallel_rank( - ) if self.dcp_size > 1 else 0 + self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0 self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size self.cp_virtual_block_size = self.cp_local_block_size * self.dcp_size * self.pcp_size scheduler_config = vllm_config.scheduler_config - decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs', - 0) + decode_max_num_seqs = getattr(scheduler_config, "decode_max_num_seqs", 0) max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs) - self.batch_seq_mask_buf = torch.empty(max_num_seqs * - self.decode_threshold, - dtype=torch.uint8, - device=device) - self.block_size = (self.block_size * - self.cp_virtual_block_size) // np.gcd( - self.block_size, self.cp_virtual_block_size) + self.batch_seq_mask_buf = torch.empty(max_num_seqs * self.decode_threshold, dtype=torch.uint8, device=device) + self.block_size = (self.block_size * self.cp_virtual_block_size) // np.gcd( + self.block_size, self.cp_virtual_block_size + ) def build( self, @@ -85,15 +86,10 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder): ) -> AscendMLAMetadata: metadata_cls = super().build(common_prefix_len, common_attn_metadata) if self.num_prefills == 0 and self.pcp_size > 1: - self.slot_mapping[:self. - num_decode_tokens] = self.slot_mapping[:self. - num_decode_tokens - * self. - pcp_size: - self. 
- pcp_size] - self.slot_mapping[self.num_decode_tokens:self.num_decode_tokens * - self.pcp_size].fill_(-1) + self.slot_mapping[: self.num_decode_tokens] = self.slot_mapping[ + : self.num_decode_tokens * self.pcp_size : self.pcp_size + ] + self.slot_mapping[self.num_decode_tokens : self.num_decode_tokens * self.pcp_size].fill_(-1) metadata_cls.slot_mapping = self.slot_mapping return metadata_cls @@ -118,8 +114,8 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder): # In dcp only spec decode graph padding case, # num_actual_tokens_pcp_padded may be less than num_actual_tokens self.num_actual_tokens = max( - long_seq_metadata.num_actual_tokens_pcp_padded, - common_attn_metadata.num_actual_tokens) + long_seq_metadata.num_actual_tokens_pcp_padded, common_attn_metadata.num_actual_tokens + ) def build_cp_metadata( self, @@ -131,30 +127,23 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder): return AscendPCPMetadata( q_head_idx=common_long_seq_metadata.q_head_idx_tensor, q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor, - kv_with_q_head_nomask_idx=common_long_seq_metadata. - kv_with_q_head_nomask_idx_tensor, - kv_with_q_head_mask_idx=common_long_seq_metadata. - kv_with_q_head_mask_idx_tensor, - kv_with_q_tail_nomask_idx=common_long_seq_metadata. - kv_with_q_tail_nomask_idx_tensor, - kv_with_q_tail_mask_idx=common_long_seq_metadata. - kv_with_q_tail_mask_idx_tensor, + kv_with_q_head_nomask_idx=common_long_seq_metadata.kv_with_q_head_nomask_idx_tensor, + kv_with_q_head_mask_idx=common_long_seq_metadata.kv_with_q_head_mask_idx_tensor, + kv_with_q_tail_nomask_idx=common_long_seq_metadata.kv_with_q_tail_nomask_idx_tensor, + kv_with_q_tail_mask_idx=common_long_seq_metadata.kv_with_q_tail_mask_idx_tensor, attn_mask_seqlens=common_long_seq_metadata.attn_mask_seqlens, - head_attn_nomask_seqlens=common_long_seq_metadata. - head_attn_nomask_seqlens, - tail_attn_nomask_seqlens=common_long_seq_metadata. - tail_attn_nomask_seqlens, + head_attn_nomask_seqlens=common_long_seq_metadata.head_attn_nomask_seqlens, + tail_attn_nomask_seqlens=common_long_seq_metadata.tail_attn_nomask_seqlens, q_full_idx=common_long_seq_metadata.q_full_idx, - pcp_allgather_restore_idx=common_long_seq_metadata. - pcp_allgather_restore_idx) + pcp_allgather_restore_idx=common_long_seq_metadata.pcp_allgather_restore_idx, + ) def build_chunked_metadata( self, common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, ): - chunked_context_metadata = super().build_chunked_metadata( - common_prefix_len, common_attn_metadata) + chunked_context_metadata = super().build_chunked_metadata(common_prefix_len, common_attn_metadata) if chunked_context_metadata is None: return None @@ -162,33 +151,37 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder): assert long_seq_metadata is not None num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp assert num_computed_tokens_of_pcp_dcp is not None - local_context_lens_allranks = torch.tensor( - num_computed_tokens_of_pcp_dcp[self.num_decodes_flatten:]).reshape( - -1, self.dcp_size * self.pcp_size) + local_context_lens_allranks = torch.tensor(num_computed_tokens_of_pcp_dcp[self.num_decodes_flatten :]).reshape( + -1, self.dcp_size * self.pcp_size + ) # Note(qcs): The max local context lengths # padded to `cp_local_block_size`. 
- padded_local_context_lens_cpu = (cdiv( - self.context_lens_cpu, - self.cp_virtual_block_size, - ) * self.cp_local_block_size) - padded_local_max_context_chunk_across_ranks = (cdiv( - self.max_context_chunk, - self.cp_virtual_block_size, - ) * self.cp_local_block_size) - local_chunk_starts = (torch.arange( - self.num_chunks, dtype=torch.int32).unsqueeze(1).expand( - -1, self.num_prefills) * - padded_local_max_context_chunk_across_ranks) + padded_local_context_lens_cpu = ( + cdiv( + self.context_lens_cpu, + self.cp_virtual_block_size, + ) + * self.cp_local_block_size + ) + padded_local_max_context_chunk_across_ranks = ( + cdiv( + self.max_context_chunk, + self.cp_virtual_block_size, + ) + * self.cp_local_block_size + ) + local_chunk_starts = ( + torch.arange(self.num_chunks, dtype=torch.int32).unsqueeze(1).expand(-1, self.num_prefills) + * padded_local_max_context_chunk_across_ranks + ) local_chunk_ends = torch.min( padded_local_context_lens_cpu.unsqueeze(0), local_chunk_starts + padded_local_max_context_chunk_across_ranks, ) - padded_local_chunk_seq_lens = (local_chunk_ends - - local_chunk_starts).clamp(min=0) - padded_local_cu_chunk_seq_lens_cpu = torch.zeros(self.num_chunks, - self.num_prefills + 1, - dtype=torch.int32, - pin_memory=True) + padded_local_chunk_seq_lens = (local_chunk_ends - local_chunk_starts).clamp(min=0) + padded_local_cu_chunk_seq_lens_cpu = torch.zeros( + self.num_chunks, self.num_prefills + 1, dtype=torch.int32, pin_memory=True + ) torch.cumsum( padded_local_chunk_seq_lens, dim=1, @@ -197,8 +190,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder): ) chunked_metadata = CPChunkedContextMetadata( cu_seq_lens=chunked_context_metadata.cu_seq_lens, - starts=local_chunk_starts.pin_memory().to(self.device, - non_blocking=True), + starts=local_chunk_starts.pin_memory().to(self.device, non_blocking=True), seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=chunked_context_metadata.max_seq_lens, chunk_seq_lens=self.chunk_seq_lens, @@ -207,18 +199,14 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder): padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(), padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(), local_context_lens_allranks=local_context_lens_allranks.tolist(), - padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu. 
- pin_memory().to(self.device, non_blocking=True), + padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.pin_memory().to(self.device, non_blocking=True), cu_seq_lens_lst=self.cu_seq_lens_cpu.tolist(), chunk_size=padded_local_max_context_chunk_across_ranks, ) return chunked_metadata - def get_block_table_size( - self, common_attn_metadata: AscendCommonAttentionMetadata, - build_metadata_step: int): - self.num_decodes_flatten = self.query_lens[:self.num_decodes].sum( - ).item() + def get_block_table_size(self, common_attn_metadata: AscendCommonAttentionMetadata, build_metadata_step: int): + self.num_decodes_flatten = self.query_lens[: self.num_decodes].sum().item() if build_metadata_step == BUILD_METADATA_STEP_PREFILL: # For pcp + spec decode, we flatten seq_lens and block_table # to avoid irregular attn_mask shape @@ -231,12 +219,9 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder): common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, ) -> AscendMLAPrefillMetadata: - prefill_metadata = super().build_prefill_metadata( - common_prefix_len, common_attn_metadata) - prefill_metadata.pcp_metadata = self.build_cp_metadata( - common_prefix_len, common_attn_metadata) - prefill_metadata.block_table = self.block_table[ - self.num_decodes_flatten:, ...] + prefill_metadata = super().build_prefill_metadata(common_prefix_len, common_attn_metadata) + prefill_metadata.pcp_metadata = self.build_cp_metadata(common_prefix_len, common_attn_metadata) + prefill_metadata.block_table = self.block_table[self.num_decodes_flatten :, ...] return prefill_metadata def build_decode_metadata( @@ -244,24 +229,20 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder): common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, ) -> AscendMLADecodeMetadata: - decode_metadata = super().build_decode_metadata( - common_prefix_len, common_attn_metadata) + decode_metadata = super().build_decode_metadata(common_prefix_len, common_attn_metadata) long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata assert long_seq_metadata is not None num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp assert num_computed_tokens_of_pcp_dcp is not None # [bs, pcp_size, dcp_size] - num_computed_tokens_of_cp_dcp_array = np.array( - num_computed_tokens_of_pcp_dcp)[:self.num_decodes_flatten] + num_computed_tokens_of_cp_dcp_array = np.array(num_computed_tokens_of_pcp_dcp)[: self.num_decodes_flatten] - cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank, - self.dcp_rank] + cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank, self.dcp_rank] cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32) - batch_seq_mask = (cp_seq_len == 0) - self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_( - batch_seq_mask, non_blocking=True) - batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.shape[0]] + batch_seq_mask = cp_seq_len == 0 + self.batch_seq_mask_buf[: batch_seq_mask.shape[0]].copy_(batch_seq_mask, non_blocking=True) + batch_seq_mask = self.batch_seq_mask_buf[: batch_seq_mask.shape[0]] cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len) decode_metadata.cp_seq_len = cp_seq_len decode_metadata.batch_seq_mask = batch_seq_mask @@ -280,30 +261,35 @@ class AscendMlaCPImpl(AscendMLAImpl): head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: 
Optional[float], + logits_soft_cap: float | None, attn_type: str, - kv_sharing_target_layer_name: Optional[str], + kv_sharing_target_layer_name: str | None, **kwargs, ): - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **kwargs) + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **kwargs, + ) self.pcp_size = get_pcp_group().world_size - self.pcp_rank = get_pcp_group( - ).rank_in_group if self.pcp_size > 1 else 0 - self.pcp_group = get_pcp_group( - ).device_group if self.pcp_size > 1 else None + self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0 + self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None self.dcp_size = get_decode_context_model_parallel_world_size() - self.dcp_rank = get_decode_context_model_parallel_rank( - ) if self.dcp_size > 1 else 0 - self.dcp_group = get_dcp_group( - ).device_group if self.dcp_size > 1 else None + self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0 + self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None def get_num_actual_tokens(self, attn_metadata: M): if self.pcp_size > 1: @@ -320,103 +306,80 @@ class AscendMlaCPImpl(AscendMLAImpl): x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) return x - def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache, - attn_metadata): + def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache, attn_metadata): if not self.pcp_size > 1: - return super().mla_preprocess_prefill(q_c, kv_no_split, kv_cache, - attn_metadata) + return super().mla_preprocess_prefill(q_c, kv_no_split, kv_cache, attn_metadata) num_decode_tokens = attn_metadata.num_decode_tokens - num_actual_tokens = (attn_metadata.num_actual_tokens_pcp_padded - - self.pcp_size * num_decode_tokens - ) // self.pcp_size + num_decode_tokens + num_actual_tokens = ( + attn_metadata.num_actual_tokens_pcp_padded - self.pcp_size * num_decode_tokens + ) // self.pcp_size + num_decode_tokens prefill_q_c = q_c[num_decode_tokens:num_actual_tokens] - prefill_q = self.q_proj(prefill_q_c)[0] \ - .view(-1, self.num_heads, self.qk_head_dim) - prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] - prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim] - cos = attn_metadata.prefill.cos[:num_actual_tokens - num_decode_tokens] - sin = attn_metadata.prefill.sin[:num_actual_tokens - num_decode_tokens] + prefill_q = self.q_proj(prefill_q_c)[0].view(-1, self.num_heads, self.qk_head_dim) + prefill_q_pe = prefill_q[..., self.qk_nope_head_dim :] + prefill_q_nope = prefill_q[..., : self.qk_nope_head_dim] + cos = attn_metadata.prefill.cos[: num_actual_tokens - num_decode_tokens] + sin = attn_metadata.prefill.sin[: num_actual_tokens - num_decode_tokens] prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) prefill_kv_no_split = kv_no_split[:num_actual_tokens] - kv_c, k_pe = prefill_kv_no_split.split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c, k_pe = prefill_kv_no_split.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) - assert len( - kv_cache - ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)" - kv_c_normed = kv_c_normed.view( - [num_actual_tokens, self.num_kv_heads, -1]) + assert 
len(kv_cache) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)" + kv_c_normed = kv_c_normed.view([num_actual_tokens, self.num_kv_heads, -1]) k_pe = k_pe.unsqueeze(1) prefill_k_pe = k_pe prefill_k_pe[num_decode_tokens:num_actual_tokens] = self.rope_single( - prefill_k_pe[num_decode_tokens:num_actual_tokens], cos, sin) + prefill_k_pe[num_decode_tokens:num_actual_tokens], cos, sin + ) prefill_k_c_normed = kv_c_normed[:num_actual_tokens] - prefill_kv_c_k_pe = torch.cat([prefill_k_c_normed, prefill_k_pe], - dim=-1) + prefill_kv_c_k_pe = torch.cat([prefill_k_c_normed, prefill_k_pe], dim=-1) prefill_kv_c_k_pe = get_pcp_group().all_gather(prefill_kv_c_k_pe, 0) prefill_kv_c_k_pe = torch.index_select( - prefill_kv_c_k_pe, 0, - attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx) - prefill_kv_c_k_pe = prefill_kv_c_k_pe[num_decode_tokens * - self.pcp_size:] - prefill_k_c_normed, prefill_k_pe = prefill_kv_c_k_pe.split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + prefill_kv_c_k_pe, 0, attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx + ) + prefill_kv_c_k_pe = prefill_kv_c_k_pe[num_decode_tokens * self.pcp_size :] + prefill_k_c_normed, prefill_k_pe = prefill_kv_c_k_pe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_c_normed, k_pe = prefill_k_c_normed, prefill_k_pe prefill_k_c_normed = prefill_k_c_normed.squeeze() - slot_mapping = attn_metadata.slot_mapping[self.pcp_size * - num_decode_tokens:] - torch_npu._npu_reshape_and_cache(key=kv_c_normed, - value=k_pe, - key_cache=kv_cache[0], - value_cache=kv_cache[1], - slot_indices=slot_mapping) - prefill_k_nope, prefill_value = self.kv_b_proj( - prefill_k_c_normed)[0].view( - -1, self.num_heads, - self.qk_nope_head_dim + self.v_head_dim).split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + slot_mapping = attn_metadata.slot_mapping[self.pcp_size * num_decode_tokens :] + torch_npu._npu_reshape_and_cache( + key=kv_c_normed, value=k_pe, key_cache=kv_cache[0], value_cache=kv_cache[1], slot_indices=slot_mapping + ) + prefill_k_nope, prefill_value = ( + self.kv_b_proj(prefill_k_c_normed)[0] + .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + ) prefill_k_pe = prefill_k_pe.expand((*prefill_k_nope.shape[:-1], -1)) - return PrefillMLAPreprocessResult(prefill_q_nope, prefill_q_pe, - prefill_k_nope, prefill_k_pe, - prefill_value) + return PrefillMLAPreprocessResult(prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value) def mla_preprocess_decode(self, q_c, kv_no_split, kv_cache, attn_metadata): num_decode_tokens = attn_metadata.num_decode_tokens decode_q_c = q_c[:num_decode_tokens] cos = attn_metadata.decode.cos sin = attn_metadata.decode.sin - decode_ql_nope, decode_q_pe = \ - self._q_proj_and_k_up_proj(decode_q_c) - decode_ql_nope, decode_q_pe = self.reorg_decode_q( - decode_ql_nope, decode_q_pe) + decode_ql_nope, decode_q_pe = self._q_proj_and_k_up_proj(decode_q_c) + decode_ql_nope, decode_q_pe = self.reorg_decode_q(decode_ql_nope, decode_q_pe) decode_q_pe = self.rope_single(decode_q_pe, cos, sin) decode_slots = attn_metadata.slot_mapping[:num_decode_tokens] decode_kv_no_split = kv_no_split[:num_decode_tokens] - decode_k_pe, decode_k_nope = self.exec_kv_decode( - decode_kv_no_split, cos, sin, kv_cache, decode_slots) - return DecodeMLAPreprocessResult(decode_ql_nope, decode_q_pe, - decode_k_nope, decode_k_pe) + decode_k_pe, decode_k_nope = self.exec_kv_decode(decode_kv_no_split, cos, 
sin, kv_cache, decode_slots) + return DecodeMLAPreprocessResult(decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe) - def get_context_seq_len_npu(self, index: int, - attn_metadata: AscendMLAMetadata): + def get_context_seq_len_npu(self, index: int, attn_metadata: AscendMLAMetadata): prefill_metadata = attn_metadata.prefill assert prefill_metadata is not None assert prefill_metadata.chunked_context is not None - assert isinstance(prefill_metadata.chunked_context, - CPChunkedContextMetadata) + assert isinstance(prefill_metadata.chunked_context, CPChunkedContextMetadata) assert prefill_metadata.chunked_context.padded_chunk_seq_lens_npu is not None iters = len(prefill_metadata.chunked_context.seq_tot) assert 0 <= index < iters - return prefill_metadata.chunked_context.padded_chunk_seq_lens_npu[ - index] + return prefill_metadata.chunked_context.padded_chunk_seq_lens_npu[index] def reorg_decode_q(self, decode_q_nope, decode_q_pe): if self.dcp_size > 1: decode_q_no_split = torch.cat([decode_q_nope, decode_q_pe], dim=-1) - decode_q_no_split = get_dcp_group().all_gather( - decode_q_no_split, 1) - decode_q_nope, decode_q_pe = decode_q_no_split.split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + decode_q_no_split = get_dcp_group().all_gather(decode_q_no_split, 1) + decode_q_nope, decode_q_pe = decode_q_no_split.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) return decode_q_nope, decode_q_pe def _forward_prefill( @@ -426,12 +389,11 @@ class AscendMlaCPImpl(AscendMLAImpl): k_nope: torch.Tensor, k_pe: torch.Tensor, value: torch.Tensor, - kv_c_and_k_pe_cache: Tuple[torch.Tensor], + kv_c_and_k_pe_cache: tuple[torch.Tensor], attn_metadata: AscendMLAMetadata, ) -> torch.Tensor: if not self.pcp_size > 1: - return super()._forward_prefill(q_nope, q_pe, k_nope, k_pe, value, - kv_c_and_k_pe_cache, attn_metadata) + return super()._forward_prefill(q_nope, q_pe, k_nope, k_pe, value, kv_c_and_k_pe_cache, attn_metadata) assert attn_metadata.prefill is not None assert attn_metadata.prefill.pcp_metadata is not None num_tokens = q_nope.size(0) @@ -455,7 +417,8 @@ class AscendMlaCPImpl(AscendMLAImpl): kv_nomask_idx=kv_with_q_head_nomask_idx, attn_mask_seqlens=attn_mask_seqlens, attn_nomask_seqlens=head_attn_nomask_seqlens, - mask=attn_metadata.attn_mask) + mask=attn_metadata.attn_mask, + ) output_tail, lse_tail = self._attention_with_mask_and_nomask( q_nope=torch.index_select(q_nope, 0, q_tail_idx), @@ -467,19 +430,16 @@ class AscendMlaCPImpl(AscendMLAImpl): kv_nomask_idx=kv_with_q_tail_nomask_idx, attn_mask_seqlens=attn_mask_seqlens, attn_nomask_seqlens=tail_attn_nomask_seqlens, - mask=attn_metadata.attn_mask) + mask=attn_metadata.attn_mask, + ) q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx - attn_output = torch.index_select( - torch.cat([output_head, output_tail], dim=0), 0, q_full_idx) - attn_lse = torch.index_select(torch.cat([lse_head, lse_tail], dim=1), - 1, q_full_idx) + attn_output = torch.index_select(torch.cat([output_head, output_tail], dim=0), 0, q_full_idx) + attn_lse = torch.index_select(torch.cat([lse_head, lse_tail], dim=1), 1, q_full_idx) - output, _ = self._compute_prefill_context(q_nope, q_pe, - kv_c_and_k_pe_cache, - self.qk_rope_head_dim, - attn_metadata, attn_output, - attn_lse) + output, _ = self._compute_prefill_context( + q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse + ) output = output.reshape([num_tokens, self.num_heads * self.v_head_dim]) @@ -498,44 +458,40 @@ class AscendMlaCPImpl(AscendMLAImpl): 
attn_nomask_seqlens: list[torch.Tensor], mask: torch.Tensor, ): - attn_output = torch.empty(q_nope.shape[0], - self.num_heads, - self.v_head_dim, - dtype=k_pe.dtype, - device=k_pe.device) - attn_lse = torch.empty(self.num_heads, - q_pe.shape[0], - dtype=torch.float32, - device=k_pe.device) + attn_output = torch.empty( + q_nope.shape[0], self.num_heads, self.v_head_dim, dtype=k_pe.dtype, device=k_pe.device + ) + attn_lse = torch.empty(self.num_heads, q_pe.shape[0], dtype=torch.float32, device=k_pe.device) # mask k_nope_mask = torch.index_select(k_nope, 0, kv_mask_idx) value_mask = torch.index_select(value, 0, kv_mask_idx) k_pe_mask = torch.index_select(k_pe, 0, kv_mask_idx) - torch_npu.atb.npu_ring_mla(q_nope=q_nope, - q_rope=q_pe, - k_nope=k_nope_mask, - k_rope=k_pe_mask, - value=value_mask, - mask=mask, - seqlen=attn_mask_seqlens, - head_num=self.num_heads, - kv_head_num=self.num_heads, - pre_out=None, - prev_lse=None, - qk_scale=self.scale, - kernel_type="kernel_type_high_precision", - mask_type="mask_type_triu", - input_layout="type_bsnd", - calc_type="calc_type_first_ring", - output=attn_output, - softmax_lse=attn_lse) + torch_npu.atb.npu_ring_mla( + q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope_mask, + k_rope=k_pe_mask, + value=value_mask, + mask=mask, + seqlen=attn_mask_seqlens, + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=None, + prev_lse=None, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="mask_type_triu", + input_layout="type_bsnd", + calc_type="calc_type_first_ring", + output=attn_output, + softmax_lse=attn_lse, + ) # nomask if not kv_nomask_idx or len(kv_nomask_idx[0]) == 0: return attn_output, attn_lse - for kv_nomask_idx_split, attn_nomask_seqlens_split in zip( - kv_nomask_idx, attn_nomask_seqlens): + for kv_nomask_idx_split, attn_nomask_seqlens_split in zip(kv_nomask_idx, attn_nomask_seqlens): k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx_split) value_nomask = torch.index_select(value, 0, kv_nomask_idx_split) k_pe_nomask = torch.index_select(k_pe, 0, kv_nomask_idx_split) @@ -557,7 +513,8 @@ class AscendMlaCPImpl(AscendMLAImpl): input_layout="type_bsnd", calc_type="calc_type_default", output=attn_output, - softmax_lse=attn_lse) + softmax_lse=attn_lse, + ) return attn_output, attn_lse def _forward_decode( @@ -579,10 +536,8 @@ class AscendMlaCPImpl(AscendMLAImpl): else: num_heads = self.num_heads - k_nope = k_nope.view(-1, block_size, self.num_kv_heads, - self.kv_lora_rank) - k_pe = k_pe.view(-1, block_size, self.num_kv_heads, - self.qk_rope_head_dim) + k_nope = k_nope.view(-1, block_size, self.num_kv_heads, self.kv_lora_rank) + k_pe = k_pe.view(-1, block_size, self.num_kv_heads, self.qk_rope_head_dim) q_nope = q_nope.view(num_tokens, num_heads, -1) q_pe = q_pe.view(num_tokens, num_heads, -1) # use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask @@ -606,20 +561,35 @@ class AscendMlaCPImpl(AscendMLAImpl): workspace = graph_params.workspaces.get(num_tokens) if workspace is None: workspace = torch_npu.atb._npu_multi_head_latent_attention_get_workspace( - q_nope, q_pe, k_nope, k_pe, decode_meta.block_table, - seq_len, num_heads, self.scale, self.num_kv_heads, - **common_kwargs) + q_nope, + q_pe, + k_nope, + k_pe, + decode_meta.block_table, + seq_len, + num_heads, + self.scale, + self.num_kv_heads, + **common_kwargs, + ) update_graph_params_workspaces(num_tokens, workspace) attn_output = torch.empty_like(q_nope) - softmax_lse = torch.empty((num_tokens, num_heads, 1), - 
dtype=q_nope.dtype, - device=q_nope.device) + softmax_lse = torch.empty((num_tokens, num_heads, 1), dtype=q_nope.dtype, device=q_nope.device) graph_params.attn_params[num_tokens].append( - (weak_ref_tensors(q_nope), weak_ref_tensors(q_pe), - weak_ref_tensors(k_nope), weak_ref_tensors(k_pe), - decode_meta.block_table, seq_len, num_heads, self.scale, - self.num_kv_heads, weak_ref_tensors(attn_output), - weak_ref_tensors(softmax_lse))) + ( + weak_ref_tensors(q_nope), + weak_ref_tensors(q_pe), + weak_ref_tensors(k_nope), + weak_ref_tensors(k_pe), + decode_meta.block_table, + seq_len, + num_heads, + self.scale, + self.num_kv_heads, + weak_ref_tensors(attn_output), + weak_ref_tensors(softmax_lse), + ) + ) torch.npu.graph_task_group_begin(stream) torch_npu.atb.npu_multi_head_latent_attention( q_nope, @@ -634,14 +604,13 @@ class AscendMlaCPImpl(AscendMLAImpl): **common_kwargs, workspace=workspace, output=attn_output, - lse=softmax_lse) + lse=softmax_lse, + ) handle = torch.npu.graph_task_group_end(stream) graph_params.handles[num_tokens].append(handle) else: attn_output = torch.empty_like(q_nope) - softmax_lse = torch.empty((num_tokens, num_heads, 1), - dtype=q_nope.dtype, - device=q_nope.device) + softmax_lse = torch.empty((num_tokens, num_heads, 1), dtype=q_nope.dtype, device=q_nope.device) torch_npu.atb.npu_multi_head_latent_attention( q_nope, q_pe, @@ -655,20 +624,17 @@ class AscendMlaCPImpl(AscendMLAImpl): return_lse=True, calc_type="calc_type_ring", output=attn_output, - lse=softmax_lse) + lse=softmax_lse, + ) # Update out&lse - attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse, - decode_meta.batch_seq_mask) + attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse, decode_meta.batch_seq_mask) attn_output = _npu_attention_update(self.kv_lora_rank, attn_out_lse) return self._v_up_proj(attn_output) - def _out_lse_reshape(self, attn_out: torch.Tensor, - attn_lse: torch.Tensor) -> torch.Tensor: - attn_out = attn_out.contiguous().view( - attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2]) - attn_lse = attn_lse.contiguous().view( - attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2]) + def _out_lse_reshape(self, attn_out: torch.Tensor, attn_lse: torch.Tensor) -> torch.Tensor: + attn_out = attn_out.contiguous().view(attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2]) + attn_lse = attn_lse.contiguous().view(attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2]) return attn_out, attn_lse def _reorg_kvcache( @@ -706,8 +672,7 @@ class AscendMlaCPImpl(AscendMLAImpl): assert chunked_context.max_seq_lens is not None assert chunked_context.chunk_size is not None - padded_local_chunk_seq_lens_lst = chunked_context.padded_local_chunk_seq_lens[ - chunk_idx] + padded_local_chunk_seq_lens_lst = chunked_context.padded_local_chunk_seq_lens[chunk_idx] local_context_lens_allranks = chunked_context.local_context_lens_allranks sum_seq_len = chunked_context.cu_seq_lens_lst[chunk_idx][-1] max_seq_len = chunked_context.max_seq_lens[chunk_idx] @@ -720,14 +685,16 @@ class AscendMlaCPImpl(AscendMLAImpl): cache_kv_c_k_pe = get_pcp_group().all_gather(cache_kv_c_k_pe, 0) allgatered_kv_c_normed, allgatered_k_pe = cache_kv_c_k_pe.split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) kv_c_segments = [] k_pe_segments = [] src_token_idx = 0 max_seq_len_check = 0 for padded_local_chunk_seq_len, local_context_lens in zip( - padded_local_chunk_seq_lens_lst, local_context_lens_allranks): + padded_local_chunk_seq_lens_lst, 
local_context_lens_allranks + ): cur_seq_len = 0 for rank, local_context_len in enumerate(local_context_lens): # Note(qcs): We split the context into multiple chunks, @@ -742,15 +709,12 @@ class AscendMlaCPImpl(AscendMLAImpl): padded_local_chunk_seq_len, ) if local_chunk_len != 0: - kv_c_segment = allgatered_kv_c_normed[rank * toks + - src_token_idx:rank * - toks + - src_token_idx + - local_chunk_len] - k_pe_segment = allgatered_k_pe[rank * toks + - src_token_idx:rank * toks + - src_token_idx + - local_chunk_len] + kv_c_segment = allgatered_kv_c_normed[ + rank * toks + src_token_idx : rank * toks + src_token_idx + local_chunk_len + ] + k_pe_segment = allgatered_k_pe[ + rank * toks + src_token_idx : rank * toks + src_token_idx + local_chunk_len + ] kv_c_segments.append(kv_c_segment) k_pe_segments.append(k_pe_segment) cur_seq_len += local_chunk_len diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 619d2278..50be439b 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -1,18 +1,15 @@ from dataclasses import dataclass, field from functools import lru_cache -from typing import Any, List, Optional +from typing import Any import torch import torch.nn.functional as F from vllm.config import VllmConfig, get_current_vllm_config -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group, - is_v1_kv_transfer_group) +from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group, is_v1_kv_transfer_group from vllm.forward_context import ForwardContext, get_forward_context from vllm.v1.attention.backends.utils import CommonAttentionMetadata -from vllm_ascend.utils import (AscendDeviceType, get_ascend_config, - get_ascend_device_type) +from vllm_ascend.utils import AscendDeviceType, get_ascend_config, get_ascend_device_type def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool: @@ -21,6 +18,7 @@ def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool: if get_ascend_device_type() == AscendDeviceType.A5: return False from vllm.config.compilation import CUDAGraphMode + cudagraph_mode = vllm_config.compilation_config.cudagraph_mode if cudagraph_mode != CUDAGraphMode.FULL_DECODE_ONLY: return False @@ -31,8 +29,7 @@ def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool: @lru_cache(maxsize=1) def enable_cp(): prefill_config = get_current_vllm_config().parallel_config - return prefill_config.prefill_context_parallel_size > 1 \ - or prefill_config.decode_context_parallel_size > 1 + return prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1 @dataclass @@ -42,13 +39,14 @@ class AscendPrefillContextParallelMetadata: Contains index tensors and sequence lengths for PCP operations. """ + pcp_allgather_restore_idx: torch.Tensor = None cp_kv_recover_idx_for_chunk: torch.Tensor = None num_actual_tokens_pcp_padded: int = 0 - num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None + num_computed_tokens_of_pcp_dcp: list[list[list[int]]] | None = None q_head_idx_tensor: torch.Tensor = None @@ -85,6 +83,7 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata): For many of the tensors we keep both NPU and CPU versions. """ + # CPU tensor of sequence lengths for host-side operations. # E.g., tensor([128, 256, 64]) for 3 requests with different seq lengths. 
seq_lens_cpu: torch.Tensor = None @@ -115,20 +114,17 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata): num_input_tokens: int = 0 # Metadata for Prefill Context Parallelism (PCP) operations. - prefill_context_parallel_metadata: Optional[ - AscendPrefillContextParallelMetadata] = None + prefill_context_parallel_metadata: AscendPrefillContextParallelMetadata | None = None # TODO: Remove it when vLLM no longer uses this function. - def unpadded(self, num_actual_tokens: int, - num_actual_reqs: int) -> "AscendCommonAttentionMetadata": + def unpadded(self, num_actual_tokens: int, num_actual_reqs: int) -> "AscendCommonAttentionMetadata": # This is only used by eagle now. It will be used for enforce_eager in the future. return AscendCommonAttentionMetadata( - query_start_loc=self.query_start_loc[:num_actual_reqs + 1], - query_start_loc_cpu=self.query_start_loc_cpu[:num_actual_reqs + 1], + query_start_loc=self.query_start_loc[: num_actual_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1], seq_lens=self.seq_lens[:num_actual_reqs], seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs], - num_computed_tokens_cpu=self. - num_computed_tokens_cpu[:num_actual_reqs], + num_computed_tokens_cpu=self.num_computed_tokens_cpu[:num_actual_reqs], num_reqs=num_actual_reqs, num_actual_tokens=num_actual_tokens, max_query_len=self.max_query_len, @@ -144,14 +140,14 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata): attn_state=self.attn_state, graph_pad_size=-1, # It should be -1 when not run in fullgraph mode. num_input_tokens=self.num_input_tokens, - prefill_context_parallel_metadata=self. - prefill_context_parallel_metadata, - max_seq_len=self.max_seq_len) + prefill_context_parallel_metadata=self.prefill_context_parallel_metadata, + max_seq_len=self.max_seq_len, + ) def filter_chunked_req_indices( seq_len: torch.Tensor, - mask_for_non_zero_chunk: Optional[List[bool]], + mask_for_non_zero_chunk: list[bool] | None, ) -> torch.Tensor: """ filter the reqs which are doing real chunk_prefill. @@ -162,14 +158,15 @@ def filter_chunked_req_indices( Returns: filtered_indices: the real chunked req's indices """ - assert mask_for_non_zero_chunk is not None and len(seq_len) == len( - mask_for_non_zero_chunk) + assert mask_for_non_zero_chunk is not None and len(seq_len) == len(mask_for_non_zero_chunk) offsets = torch.cumsum(torch.cat([torch.tensor([0]), seq_len[:-1]]), dim=0) - filtered_indices = torch.cat([ - torch.arange(offsets[i], offsets[i] + seq_len[i]) - for i in range(len(mask_for_non_zero_chunk)) - if mask_for_non_zero_chunk[i] - ]) + filtered_indices = torch.cat( + [ + torch.arange(offsets[i], offsets[i] + seq_len[i]) + for i in range(len(mask_for_non_zero_chunk)) + if mask_for_non_zero_chunk[i] + ] + ) return filtered_indices @@ -195,12 +192,9 @@ def split_decodes_and_prefills( num_prefill_tokens: The number of tokens in the prefill requests. 
""" long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata - query_lens_pcp_full = long_seq_metadata.query_lens_pcp_full_cpu \ - if long_seq_metadata else None - max_query_len_pcp_full = long_seq_metadata.max_query_len_pcp_full \ - if long_seq_metadata else 0 - max_query_len = common_attn_metadata.max_query_len \ - if max_query_len_pcp_full == 0 else max_query_len_pcp_full + query_lens_pcp_full = long_seq_metadata.query_lens_pcp_full_cpu if long_seq_metadata else None + max_query_len_pcp_full = long_seq_metadata.max_query_len_pcp_full if long_seq_metadata else 0 + max_query_len = common_attn_metadata.max_query_len if max_query_len_pcp_full == 0 else max_query_len_pcp_full num_reqs = common_attn_metadata.num_reqs num_tokens = common_attn_metadata.num_actual_tokens query_start_loc = common_attn_metadata.query_start_loc_cpu @@ -208,8 +202,7 @@ def split_decodes_and_prefills( if max_query_len <= decode_threshold: return num_reqs, 0, num_tokens, 0 - query_lens = (query_start_loc[1:] - query_start_loc[:-1]) \ - if query_lens_pcp_full is None else query_lens_pcp_full + query_lens = (query_start_loc[1:] - query_start_loc[:-1]) if query_lens_pcp_full is None else query_lens_pcp_full is_prefill = query_lens > decode_threshold if not torch.any(is_prefill): return num_reqs, 0, num_tokens, 0 @@ -238,7 +231,7 @@ def wait_for_kv_layer_from_connector(layer_name: str): def maybe_save_kv_layer_to_connector( layer_name: str, - kv_cache_layer: List[torch.Tensor], + kv_cache_layer: list[torch.Tensor], ): if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): return @@ -264,8 +257,7 @@ def trans_rope_weight(weight, rope_dim): return weight.contiguous() nope_part = weight[..., :-rope_dim, :] rope_part = weight[..., -rope_dim:, :] - reordered_rope_part = torch.cat( - (rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2) + reordered_rope_part = torch.cat((rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2) return torch.cat((nope_part, reordered_rope_part), dim=-2).contiguous() @@ -278,12 +270,9 @@ def transdata(nd_mat, block_size: tuple = (16, 16)): nz_mat = torch.permute( torch.reshape( nd_mat, - (r // block_size[0], block_size[0], c // block_size[1], - block_size[1]), + (r // block_size[0], block_size[0], c // block_size[1], block_size[1]), ), [2, 0, 1, 3], ) - nz_mat = torch.reshape( - nz_mat, - (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3])) + nz_mat = torch.reshape(nz_mat, (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3])) return nz_mat diff --git a/vllm_ascend/batch_invariant.py b/vllm_ascend/batch_invariant.py index 50acadb1..fa3e10a9 100644 --- a/vllm_ascend/batch_invariant.py +++ b/vllm_ascend/batch_invariant.py @@ -27,8 +27,12 @@ logger = init_logger(__name__) if HAS_TRITON: from vllm_ascend.ops.triton.batch_invariant.matmul import ( - addmm_batch_invariant, bmm_batch_invariant, linear_batch_invariant, - matmul_batch_invariant, mm_batch_invariant) + addmm_batch_invariant, + bmm_batch_invariant, + linear_batch_invariant, + matmul_batch_invariant, + mm_batch_invariant, + ) def override_envs_for_invariance(): @@ -73,10 +77,11 @@ def init_batch_invariance(): if vllm_is_batch_invariant(): if HAS_TRITON: logger.info( - "Enabling batch-invariant mode for vLLM on Ascend NPU.", ) + "Enabling batch-invariant mode for vLLM on Ascend NPU.", + ) override_envs_for_invariance() enable_batch_invariant_mode() else: logger.warning( - "Batch-invariant mode requested but Triton is not available." 
- "skipping batch-invariant initialization.", ) + "Batch-invariant mode requested but Triton is not available.skipping batch-invariant initialization.", + ) diff --git a/vllm_ascend/device/device_op.py b/vllm_ascend/device/device_op.py index ccd874dd..92e7e8fa 100644 --- a/vllm_ascend/device/device_op.py +++ b/vllm_ascend/device/device_op.py @@ -15,35 +15,26 @@ # limitations under the License. # This file is a part of the vllm-ascend project. # -from typing import Optional, Type import torch_npu from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type -class BaseDeviceAdaptor(object): - +class BaseDeviceAdaptor: @classmethod - def reshape_and_cache(cls, key, value, key_cache, value_cache, - slot_mapping): - torch_npu._npu_reshape_and_cache(key=key, - value=value, - key_cache=key_cache, - value_cache=value_cache, - slot_indices=slot_mapping) + def reshape_and_cache(cls, key, value, key_cache, value_cache, slot_mapping): + torch_npu._npu_reshape_and_cache( + key=key, value=value, key_cache=key_cache, value_cache=value_cache, slot_indices=slot_mapping + ) class A5DeviceAdaptor(BaseDeviceAdaptor): - @classmethod - def reshape_and_cache(cls, key, value, key_cache, value_cache, - slot_mapping): - torch_npu.npu_scatter_pa_kv_cache(key=key, - value=value.contiguous(), - key_cache=key_cache, - value_cache=value_cache, - slot_mapping=slot_mapping) + def reshape_and_cache(cls, key, value, key_cache, value_cache, slot_mapping): + torch_npu.npu_scatter_pa_kv_cache( + key=key, value=value.contiguous(), key_cache=key_cache, value_cache=value_cache, slot_mapping=slot_mapping + ) def get_device_adaptor(): @@ -53,4 +44,4 @@ def get_device_adaptor(): return BaseDeviceAdaptor -DeviceOperator: Optional[Type['BaseDeviceAdaptor']] = get_device_adaptor() +DeviceOperator: type["BaseDeviceAdaptor"] | None = get_device_adaptor() diff --git a/vllm_ascend/device_allocator/camem.py b/vllm_ascend/device_allocator/camem.py index 1054263e..4971f137 100644 --- a/vllm_ascend/device_allocator/camem.py +++ b/vllm_ascend/device_allocator/camem.py @@ -18,21 +18,22 @@ # import dataclasses import os +from collections.abc import Callable from contextlib import contextmanager -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any import torch from acl.rt import memcpy # type: ignore # noqa: F401 from vllm.logger import logger -def find_loaded_library(lib_name) -> Optional[str]: +def find_loaded_library(lib_name) -> str | None: """ According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, the file `/proc/self/maps` contains the memory maps of the process, which includes the shared libraries loaded by the process. We can use this file to find the path of the a loaded library. 
- """ # noqa + """ # noqa found_line = None with open("/proc/self/maps") as f: for line in f: @@ -47,20 +48,22 @@ def find_loaded_library(lib_name) -> Optional[str]: start = found_line.index("/") path = found_line[start:].strip() filename = path.split("/")[-1] - assert filename.rpartition(".so")[0].startswith(lib_name), \ - f"Unexpected filename: {filename} for library {lib_name}" + assert filename.rpartition(".so")[0].startswith(lib_name), f"Unexpected filename: {filename} for library {lib_name}" return path camem_available = False try: from vllm_ascend.vllm_ascend_C import ( # type: ignore # noqa: F401 - init_module, python_create_and_map, python_unmap_and_release) + init_module, + python_create_and_map, + python_unmap_and_release, + ) + lib_name = find_loaded_library("vllm_ascend_C") camem_available = True except ImportError as e: - logger.warning( - "Failed to import vllm_ascend_C:%s. Sleep mode will be disabled. ", e) + logger.warning("Failed to import vllm_ascend_C:%s. Sleep mode will be disabled. ", e) init_module = None python_create_and_map = None python_unmap_and_release = None @@ -68,14 +71,14 @@ except ImportError as e: libcudart = None # py_device, py_alignedSize, py_d_mem, py_p_memHandle -HandleType = Tuple[int, int, int, int] +HandleType = tuple[int, int, int, int] @dataclasses.dataclass class AllocationData: handle: HandleType tag: str - cpu_backup_tensor: Optional[torch.Tensor] = None + cpu_backup_tensor: torch.Tensor | None = None def create_and_map(allocation_handle: HandleType) -> None: @@ -88,18 +91,18 @@ def unmap_and_release(allocation_handle: HandleType) -> None: def get_pluggable_allocator( python_malloc_fn: Callable[[tuple[int, int, int, int]], None], - python_free_func: Callable[[int], tuple[int, int, int, int]] + python_free_func: Callable[[int], tuple[int, int, int, int]], ) -> torch.npu.memory.NPUPluggableAllocator: init_module(python_malloc_fn, python_free_func) - new_alloc = torch.npu.memory.NPUPluggableAllocator(lib_name, 'my_malloc', - 'my_free') + new_alloc = torch.npu.memory.NPUPluggableAllocator(lib_name, "my_malloc", "my_free") return new_alloc @contextmanager def use_memory_pool_with_allocator( - python_malloc_fn: Callable[[tuple[int, int, int, int]], None], - python_free_func: Callable[[int], tuple[int, int, int, int]]): + python_malloc_fn: Callable[[tuple[int, int, int, int]], None], + python_free_func: Callable[[int], tuple[int, int, int, int]], +): new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func) mem_pool = torch.npu.memory.MemPool(new_alloc._allocator) with torch.npu.memory.use_mem_pool(mem_pool): @@ -127,6 +130,7 @@ class CaMemAllocator: the global variable will be overwritten and the free callback will not work as expected. """ + instance = None default_tag: str = "default" @@ -143,22 +147,22 @@ class CaMemAllocator: def __init__(self): conf = os.environ.get("PYTORCH_NPU_ALLOC_CONF", "") - assert "expandable_segments:True" not in conf, \ - ("Expandable segments are not compatible with memory pool. " + assert "expandable_segments:True" not in conf, ( + "Expandable segments are not compatible with memory pool. " "Please track https://github.com/pytorch/pytorch/issues/147851 " - "for the latest updates.") + "for the latest updates." 
+ ) - self.pointer_to_data: Dict[int, AllocationData] = {} + self.pointer_to_data: dict[int, AllocationData] = {} self.current_tag: str = CaMemAllocator.default_tag - self.allocator_and_pools: Dict[str, Any] = {} + self.allocator_and_pools: dict[str, Any] = {} def python_malloc_callback(self, allocation_handle: HandleType) -> None: """ Internal method to store the allocation data when memory is allocated in the memory pool.""" py_d_mem = allocation_handle[2] - self.pointer_to_data[py_d_mem] = AllocationData( - allocation_handle, self.current_tag) + self.pointer_to_data[py_d_mem] = AllocationData(allocation_handle, self.current_tag) return def python_free_callback(self, ptr: int) -> HandleType: @@ -170,13 +174,10 @@ class CaMemAllocator: data.cpu_backup_tensor = None return data.handle - def sleep( - self, - offload_tags: Optional[Union[Tuple[str, ...], - str]] = None) -> None: + def sleep(self, offload_tags: tuple[str, ...] | str | None = None) -> None: """ Put the allocator in sleep mode. - All data in the memory allocation with the specified tag will be + All data in the memory allocation with the specified tag will be offloaded to CPU memory, and others will be discarded. :param offload_tags: The tags of the memory allocation that will be offloaded. The rest of the memory allocation will be discarded. @@ -184,9 +185,9 @@ class CaMemAllocator: if offload_tags is None: # by default, allocated tensors are offloaded # when the allocator sleeps - offload_tags = (CaMemAllocator.default_tag, ) + offload_tags = (CaMemAllocator.default_tag,) elif isinstance(offload_tags, str): - offload_tags = (offload_tags, ) + offload_tags = (offload_tags,) assert isinstance(offload_tags, tuple) @@ -194,22 +195,18 @@ class CaMemAllocator: handle = data.handle if data.tag in offload_tags: size_in_bytes = handle[1] - cpu_backup_tensor = torch.empty(size_in_bytes, - dtype=torch.uint8, - device='cpu', - pin_memory=True) + cpu_backup_tensor = torch.empty(size_in_bytes, dtype=torch.uint8, device="cpu", pin_memory=True) cpu_ptr = cpu_backup_tensor.data_ptr() ACL_MEMCPY_DEVICE_TO_HOST = 2 dest_max = cpu_ptr + size_in_bytes * 2 - memcpy(cpu_ptr, dest_max, ptr, size_in_bytes, - ACL_MEMCPY_DEVICE_TO_HOST) + memcpy(cpu_ptr, dest_max, ptr, size_in_bytes, ACL_MEMCPY_DEVICE_TO_HOST) data.cpu_backup_tensor = cpu_backup_tensor unmap_and_release(handle) - def wake_up(self, tags: Optional[list[str]] = None) -> None: + def wake_up(self, tags: list[str] | None = None) -> None: """ Wake up the allocator from sleep mode. 
- All data that is previously offloaded will be loaded back to GPU + All data that is previously offloaded will be loaded back to GPU memory, and the rest of the data will have empty memory.""" for ptr, data in self.pointer_to_data.items(): if tags is None or data.tag in tags: @@ -218,20 +215,18 @@ class CaMemAllocator: if data.cpu_backup_tensor is not None: cpu_backup_tensor = data.cpu_backup_tensor if cpu_backup_tensor is not None: - size_in_bytes = cpu_backup_tensor.numel( - ) * cpu_backup_tensor.element_size() + size_in_bytes = cpu_backup_tensor.numel() * cpu_backup_tensor.element_size() cpu_ptr = cpu_backup_tensor.data_ptr() ACL_MEMCPY_HOST_TO_DEVICE = 1 dest_max = ptr + size_in_bytes * 2 - memcpy(ptr, dest_max, cpu_ptr, size_in_bytes, - ACL_MEMCPY_HOST_TO_DEVICE) + memcpy(ptr, dest_max, cpu_ptr, size_in_bytes, ACL_MEMCPY_HOST_TO_DEVICE) data.cpu_backup_tensor = None @contextmanager - def use_memory_pool(self, tag: Optional[str] = None): + def use_memory_pool(self, tag: str | None = None): """ A context manager to use the memory pool. - All memory allocation created inside the context will be allocated + All memory allocation created inside the context will be allocated in the memory pool, and has the specified tag. :param tag: The tag of the memory allocation. If None, the default tag will be used. @@ -243,8 +238,7 @@ class CaMemAllocator: old_tag = self.current_tag self.current_tag = tag - with use_memory_pool_with_allocator(self.python_malloc_callback, - self.python_free_callback) as data: + with use_memory_pool_with_allocator(self.python_malloc_callback, self.python_free_callback) as data: # start to hit another PyTorch bug in PyTorch 2.6, # possibly because of gc-related issue w.r.t. the allocator and # the memory pool. diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index bc31abd1..94645b18 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -19,107 +19,89 @@ # import os -from typing import Any, Callable, Dict +from collections.abc import Callable +from typing import Any # The begin-* and end* here are used by the documentation generator # to extract the used env vars. # begin-env-vars-definition -env_variables: Dict[str, Callable[[], Any]] = { +env_variables: dict[str, Callable[[], Any]] = { # max compile thread number for package building. Usually, it is set to # the number of CPU cores. If not set, the default value is None, which # means all number of CPU cores will be used. - "MAX_JOBS": - lambda: os.getenv("MAX_JOBS", None), + "MAX_JOBS": lambda: os.getenv("MAX_JOBS", None), # The build type of the package. It can be one of the following values: # Release, Debug, RelWithDebugInfo. If not set, the default value is Release. - "CMAKE_BUILD_TYPE": - lambda: os.getenv("CMAKE_BUILD_TYPE"), + "CMAKE_BUILD_TYPE": lambda: os.getenv("CMAKE_BUILD_TYPE"), # The CXX compiler used for compiling the package. If not set, the default # value is None, which means the system default CXX compiler will be used. - "CXX_COMPILER": - lambda: os.getenv("CXX_COMPILER", None), + "CXX_COMPILER": lambda: os.getenv("CXX_COMPILER", None), # The C compiler used for compiling the package. If not set, the default # value is None, which means the system default C compiler will be used. - "C_COMPILER": - lambda: os.getenv("C_COMPILER", None), + "C_COMPILER": lambda: os.getenv("C_COMPILER", None), # The version of the Ascend chip. It's used for package building. # If not set, we will query chip info through `npu-smi`. # Please make sure that the version is correct. 
- "SOC_VERSION": - lambda: os.getenv("SOC_VERSION", None), + "SOC_VERSION": lambda: os.getenv("SOC_VERSION", None), # If set, vllm-ascend will print verbose logs during compilation - "VERBOSE": - lambda: bool(int(os.getenv('VERBOSE', '0'))), + "VERBOSE": lambda: bool(int(os.getenv("VERBOSE", "0"))), # The home path for CANN toolkit. If not set, the default value is # /usr/local/Ascend/ascend-toolkit/latest - "ASCEND_HOME_PATH": - lambda: os.getenv("ASCEND_HOME_PATH", None), + "ASCEND_HOME_PATH": lambda: os.getenv("ASCEND_HOME_PATH", None), # The path for HCCL library, it's used by pyhccl communicator backend. If # not set, the default value is libhccl.so. - "HCCL_SO_PATH": - lambda: os.environ.get("HCCL_SO_PATH", None), + "HCCL_SO_PATH": lambda: os.environ.get("HCCL_SO_PATH", None), # The version of vllm is installed. This value is used for developers who # installed vllm from source locally. In this case, the version of vllm is # usually changed. For example, if the version of vllm is "0.9.0", but when # it's installed from source, the version of vllm is usually set to "0.9.1". # In this case, developers need to set this value to "0.9.0" to make sure # that the correct package is installed. - "VLLM_VERSION": - lambda: os.getenv("VLLM_VERSION", None), + "VLLM_VERSION": lambda: os.getenv("VLLM_VERSION", None), # Whether to enable the model execute time observe profile. Disable it when # running vllm ascend in production environment. - "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": - lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0')) - ), + "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": lambda: bool( + int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", "0")) + ), # Some models are optimized by vllm ascend. While in some case, e.g. rlhf # training, the optimized model may not be suitable. In this case, set this # value to False to disable the optimized model. - "USE_OPTIMIZED_MODEL": - lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))), + "USE_OPTIMIZED_MODEL": lambda: bool(int(os.getenv("USE_OPTIMIZED_MODEL", "1"))), # Whether to enable MatmulAllReduce fusion kernel when tensor parallel is enabled. # this feature is supported in A2, and eager mode will get better performance. - "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE": - lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))), + "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", "0"))), # Whether to enable FlashComm optimization when tensor parallel is enabled. # This feature will get better performance when concurrency is large. - "VLLM_ASCEND_ENABLE_FLASHCOMM1": - lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))), + "VLLM_ASCEND_ENABLE_FLASHCOMM1": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0"))), # Whether to enable FLASHCOMM2. Setting it to 0 disables the feature, while setting it to 1 or above enables it. # The specific value set will be used as the O-matrix TP group size for flashcomm2. # For a detailed introduction to the parameters and the differences and applicable scenarios # between this feature and FLASHCOMM1, please refer to the feature guide in the documentation. - "VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": - lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)), + "VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)), # Whether to enable MLP weight prefetch, only used in small concurrency. 
- "VLLM_ASCEND_ENABLE_PREFETCH_MLP": - lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))), + "VLLM_ASCEND_ENABLE_PREFETCH_MLP": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", "0"))), # buffer size for gate up prefetch - "VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE": - lambda: int( - os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)), + "VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE": lambda: int( + os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024) + ), # buffer size for down proj prefetch - "VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE": - lambda: int( - os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)), + "VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE": lambda: int( + os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024) + ), # Whether to enable msMonitor tool to monitor the performance of vllm-ascend. - "MSMONITOR_USE_DAEMON": - lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", '0'))), - "VLLM_ASCEND_ENABLE_MLAPO": - lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", '0'))), + "MSMONITOR_USE_DAEMON": lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", "0"))), + "VLLM_ASCEND_ENABLE_MLAPO": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", "0"))), # Whether to enable weight cast format to FRACTAL_NZ. # 0: close nz; # 1: only quant case enable nz; # 2: enable nz as long as possible. - "VLLM_ASCEND_ENABLE_NZ": - lambda: int(os.getenv("VLLM_ASCEND_ENABLE_NZ", 1)), + "VLLM_ASCEND_ENABLE_NZ": lambda: int(os.getenv("VLLM_ASCEND_ENABLE_NZ", 1)), # Decide whether we should enable CP parallelism. - "VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL": - lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL", '0'))), + "VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL", "0"))), # Whether to anbale dynamic EPLB - "DYNAMIC_EPLB": - lambda: os.getenv("DYNAMIC_EPLB", "false").lower(), + "DYNAMIC_EPLB": lambda: os.getenv("DYNAMIC_EPLB", "false").lower(), # Whether to enable fused mc2(`dispatch_gmm_combine_decode`/`dispatch_ffn_combine` operator) # 0, or not set: default ALLTOALL and MC2 will be used. # 1: ALLTOALL and MC2 might be replaced by `dispatch_ffn_combine` operator. @@ -127,11 +109,9 @@ env_variables: Dict[str, Callable[[], Any]] = { # 2: MC2 might be replaced by `dispatch_gmm_combine_decode` operator. # `dispatch_gmm_combine_decode` can be used only for **decode node** moe layer # with W8A8. And MTP layer must be W8A8. - "VLLM_ASCEND_ENABLE_FUSED_MC2": - lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')), + "VLLM_ASCEND_ENABLE_FUSED_MC2": lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", "0")), # Whether to anbale balance scheduling - "VLLM_ASCEND_BALANCE_SCHEDULING": - lambda: bool(int(os.getenv("VLLM_ASCEND_BALANCE_SCHEDULING", '0'))), + "VLLM_ASCEND_BALANCE_SCHEDULING": lambda: bool(int(os.getenv("VLLM_ASCEND_BALANCE_SCHEDULING", "0"))), } # end-env-vars-definition