diff --git a/tests/e2e/multicard/4-cards/long_sequence/test_basic.py b/tests/e2e/multicard/4-cards/long_sequence/test_basic.py index 1af86eb6..db8a26fe 100644 --- a/tests/e2e/multicard/4-cards/long_sequence/test_basic.py +++ b/tests/e2e/multicard/4-cards/long_sequence/test_basic.py @@ -56,6 +56,20 @@ def test_models_pcp_dcp_basic(): quantization="ascend", ) as runner: runner.model.generate(prompts, sampling_params) + + model = "vllm-ascend/DeepSeek-V3.2-W8A8-Pruning" + with VllmRunner( + model, + enforce_eager=True, + max_model_len=1024, + tensor_parallel_size=2, + prefill_context_parallel_size=2, + decode_context_parallel_size=2, + enable_expert_parallel=True, + block_size=128, + quantization="ascend", + ) as runner: + runner.model.generate(prompts, sampling_params) def test_models_pcp_dcp_full_graph(): diff --git a/vllm_ascend/attention/context_parallel/common_cp.py b/vllm_ascend/attention/context_parallel/common_cp.py index 03038b25..8e9517e2 100644 --- a/vllm_ascend/attention/context_parallel/common_cp.py +++ b/vllm_ascend/attention/context_parallel/common_cp.py @@ -26,6 +26,9 @@ class AscendPCPMetadata: tail_attn_nomask_seqlens: torch.Tensor = None q_full_idx: torch.Tensor = None pcp_allgather_restore_idx: list[int] | None = None + block_table_cp: torch.Tensor = None + valid_block_ids: torch.Tensor = None + prefill_q_cum_seqlens: torch.Tensor = None @dataclass diff --git a/vllm_ascend/attention/context_parallel/sfa_cp.py b/vllm_ascend/attention/context_parallel/sfa_cp.py new file mode 100644 index 00000000..7fd3e739 --- /dev/null +++ b/vllm_ascend/attention/context_parallel/sfa_cp.py @@ -0,0 +1,440 @@ +from typing import TypeVar + +import numpy as np +import torch +import torch_npu +from vllm.config import VllmConfig +from vllm.distributed import get_dcp_group, get_pcp_group + +from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata +from vllm_ascend.attention.sfa_v1 import AscendSFAImpl, AscendSFAMetadata, AscendSFAMetadataBuilder +from 
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, enabling_mlapo, split_decodes_and_prefills

M = TypeVar("M", bound=AscendSFAMetadata)


class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
    """SFA metadata builder for context parallelism (PCP and/or DCP).

    NOTE: Please read the comment at the top of the file before trying to
    understand this class.

    Extends the base SFA builder with a context-parallel view of the KV
    block table: the physical block ids are compacted and re-indexed into
    one virtual table per (pcp, dcp) rank so that the attention kernels can
    address the KV cache after it has been gathered across CP ranks.
    """

    def __init__(
        self,
        kv_cache_spec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
        metadata_cls: type[AscendSFAMetadata] | None = None,
        supports_dcp_with_varlen: bool = False,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device, metadata_cls, supports_dcp_with_varlen)

        # In sfa, pcp prefill does not support mlapo
        self.enable_mlapo = enabling_mlapo(self.vllm_config)

        self.pcp_size = get_pcp_group().world_size
        self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
        self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None

        self.dcp_size = get_dcp_group().world_size
        self.dcp_rank = get_dcp_group().rank_in_group if self.dcp_size > 1 else 0
        self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None

        self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size
        self.cp_virtual_block_size = self.cp_local_block_size * self.dcp_size * self.pcp_size
        # Round block_size up to the least common multiple of the physical
        # block size and the virtual CP block size. np.lcm replaces the
        # manual (a * b) // gcd(a, b); int() keeps block_size a plain Python
        # int rather than a numpy scalar leaking into later arithmetic.
        self.block_size = int(np.lcm(self.block_size, self.cp_virtual_block_size))

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: AscendCommonAttentionMetadata,
        fast_build: bool = False,
    ) -> AscendSFAMetadata:
        """Build the base SFA metadata, then augment it with CP information.

        Adds decode/prefill token counts, the per-rank virtual block table,
        and (when PCP is active) the rebased prefill query cumulative
        lengths plus the possibly-compacted slot mapping.
        """
        # `metadata` is the built instance; the class was given to __init__.
        metadata = super().build(common_prefix_len, common_attn_metadata, fast_build)
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split_decodes_and_prefills(
            common_attn_metadata, decode_threshold=self.decode_threshold
        )
        num_reqs = common_attn_metadata.num_reqs
        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == common_attn_metadata.num_actual_tokens

        block_table = metadata.block_table
        # Compact referenced block ids; `new_block_table` holds indices into
        # `valid_block_ids` (torch.unique with return_inverse).
        valid_block_ids, new_block_table = block_table.flatten().unique(return_inverse=True)
        num_blocks = valid_block_ids.shape[0]
        # Note(qcs): `block_table_cp` will have dirty values in the part beyond kv_lens.
        # We assume that we can always get the correct kv_lens or kv index,
        # so we omit the dirty value processing here.
        block_table_cp = (
            new_block_table.unsqueeze(-1).to(block_table)
            + (torch.arange(self.pcp_size * self.dcp_size) * num_blocks).view(1, 1, -1).to(block_table)
        ).reshape(block_table.shape[0], -1)

        sfa_cp_metadata = self.build_cp_metadata(
            block_table_cp, valid_block_ids, metadata.seq_lens, common_attn_metadata
        )
        metadata.num_decode_tokens = num_decode_tokens
        metadata.num_decodes = num_decodes
        metadata.num_prefills = num_prefills

        if self.pcp_size > 1:
            long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
            assert long_seq_metadata is not None
            num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded
            slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded]
            if self.enable_mlapo:
                # Decode slots are strided across pcp ranks: compact them to
                # the front and invalidate the now-stale tail entries.
                slot_mapping[:num_decode_tokens] = slot_mapping[: num_decode_tokens * self.pcp_size : self.pcp_size]
                slot_mapping[num_decode_tokens : num_decode_tokens * self.pcp_size].fill_(-1)
            metadata.slot_mapping = slot_mapping
            actual_seq_lengths_query = metadata.cum_query_lens
            if num_prefills > 0 and num_decode_tokens > 0:
                # Rebase prefill cumulative query lengths so they start at 0.
                prefill_q_cum_seqlens = (
                    actual_seq_lengths_query[num_decode_tokens:] - actual_seq_lengths_query[num_decode_tokens - 1]
                )
            else:
                prefill_q_cum_seqlens = actual_seq_lengths_query
            assert sfa_cp_metadata is not None
            sfa_cp_metadata.prefill_q_cum_seqlens = prefill_q_cum_seqlens
        metadata.sfa_cp_metadata = sfa_cp_metadata
        return metadata

    def build_cp_metadata(
        self,
        block_table_cp: torch.Tensor,
        valid_block_ids: torch.Tensor,
        seq_lens: torch.Tensor,
        common_attn_metadata: AscendCommonAttentionMetadata,
    ) -> AscendPCPMetadata | None:
        """Assemble the PCP metadata: head/tail split indices, no-mask KV
        lengths for each half, and the CP virtual block table."""
        common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
        assert common_long_seq_metadata is not None
        # KV visible to the head half grows with rank; the tail half sees the
        # complementary span (zig-zag style balancing across pcp ranks).
        q_head_kv_lens = (seq_lens // 2) * (self.pcp_rank + 1)
        q_tail_kv_lens = seq_lens * self.pcp_size - (seq_lens // 2) * self.pcp_rank
        return AscendPCPMetadata(
            q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
            q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
            q_full_idx=common_long_seq_metadata.q_full_idx,
            head_attn_nomask_seqlens=q_head_kv_lens,
            tail_attn_nomask_seqlens=q_tail_kv_lens,
            pcp_allgather_restore_idx=common_long_seq_metadata.pcp_allgather_restore_idx,
            block_table_cp=block_table_cp,
            valid_block_ids=valid_block_ids,
        )


class AscendSFACPImpl(AscendSFAImpl):
    """SFA attention implementation under context parallelism (PCP/DCP).

    NOTE: Please read the comment at the top of the file before trying to
    understand this class.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        **kwargs,
    ):
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **kwargs,
        )
        # In sfa, pcp prefill does not support mlapo
        self.enable_mlapo = enabling_mlapo(self.vllm_config)
        self.pcp_size = get_pcp_group().world_size
        self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
        self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None

        self.dcp_size = get_dcp_group().world_size
        self.dcp_rank = get_dcp_group().rank_in_group if self.dcp_size > 1 else 0
        self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None
+ def _execute_sparse_flash_attention_process( + self, ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key + ): + kv = kv_cache[0] + key_rope = kv_cache[1] + + assert attn_metadata.sfa_cp_metadata is not None + valid_block_ids = attn_metadata.sfa_cp_metadata.valid_block_ids + kv = self.gather_kv_cross_cp(kv, valid_block_ids) + key_rope = self.gather_kv_cross_cp(key_rope, valid_block_ids) + block_table = attn_metadata.sfa_cp_metadata.block_table_cp + + if self.pcp_size == 1: + return self._execute_sparse_flash_attention( + ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key + ) + num_decode_tokens = attn_metadata.num_decode_tokens + num_prefills = attn_metadata.num_prefills + decode_attn_out = None + if num_decode_tokens > 0: + decode_attn_out = self._execute_sparse_flash_attention( + ql_nope[:num_decode_tokens], + q_pe[:num_decode_tokens], + kv, + key_rope, + block_table[:num_decode_tokens], + topk_indices[:num_decode_tokens], + actual_seq_lengths_query[:num_decode_tokens], + actual_seq_lengths_key[:num_decode_tokens], + ) + + if num_prefills < 1: + return decode_attn_out + + # q split for head and tail + q_head_idx = attn_metadata.sfa_cp_metadata.q_head_idx + q_tail_idx = attn_metadata.sfa_cp_metadata.q_tail_idx + ql_nope = ql_nope[num_decode_tokens:] + q_pe = q_pe[num_decode_tokens:] + topk_indices = topk_indices[num_decode_tokens:] + block_table = block_table[num_decode_tokens:] + + # q head compute + q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decode_tokens:] + q_head_output = self._execute_sparse_flash_attention( + torch.index_select(ql_nope, 0, q_head_idx), + torch.index_select(q_pe, 0, q_head_idx), + kv, + key_rope, + block_table, + torch.index_select(topk_indices, 0, q_head_idx), + attn_metadata.sfa_cp_metadata.prefill_q_cum_seqlens // 2, + q_head_actual_seq_lengths_key, + ) + + # q tail compute + 
q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decode_tokens:] + q_tail_output = self._execute_sparse_flash_attention( + torch.index_select(ql_nope, 0, q_tail_idx), + torch.index_select(q_pe, 0, q_tail_idx), + kv, + key_rope, + block_table, + torch.index_select(topk_indices, 0, q_tail_idx), + attn_metadata.sfa_cp_metadata.prefill_q_cum_seqlens // 2, + q_tail_actual_seq_lengths_key, + ) + + q_full_idx = attn_metadata.sfa_cp_metadata.q_full_idx + attn_output = torch.index_select(torch.cat([q_head_output, q_tail_output], dim=0), 0, q_full_idx) + + if decode_attn_out is not None: + attn_output = torch.cat([decode_attn_out, attn_output], dim=0) + return attn_output + + def _execute_sparse_flash_attention( + self, ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key + ): + attn_output = torch.ops._C_ascend.npu_sparse_flash_attention( + query=ql_nope, + key=kv, + value=kv, + sparse_indices=topk_indices, + scale_value=self.scale, + sparse_block_size=1, + block_table=block_table, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_kv=actual_seq_lengths_key, + query_rope=q_pe, + key_rope=key_rope, + layout_query="TND", + layout_kv="PA_BSND", + sparse_mode=3, + ) + return attn_output + + def gather_kv_cross_cp(self, kv_cache: torch.Tensor, valid_block_ids: torch.Tensor) -> torch.Tensor: + # Note(qcs): we need set kv_cache_interleave_size = block_size for sfa!!! 
+ kv_cache = torch.index_select(kv_cache, 0, valid_block_ids) + if self.dcp_size > 1: + kv_cache = get_dcp_group().all_gather(kv_cache, 0) + if self.pcp_size > 1: + kv_cache = get_pcp_group().all_gather(kv_cache, 0) + return kv_cache + + def indexer_select_post_process( + self, + x: torch.Tensor, + qr: torch.Tensor, + q: torch.Tensor | None, + k: torch.Tensor, + kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + attn_metadata: M, + cos: torch.Tensor, + sin: torch.Tensor, + actual_seq_lengths_query: torch.Tensor, + actual_seq_lengths_key: torch.Tensor, + need_gather_q_kv: bool = False, + ): + if q is None: + q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] + q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] + cos_q, sin_q = cos, sin + + q_pe, q_nope = torch.split( + q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1 + ) # [b,s,64,64+64] + + q_pe = q_pe.unsqueeze(2) + q_pe = torch_npu.npu_rotary_mul(q_pe, cos_q, sin_q) + q_pe = q_pe.squeeze(2) + q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128] + + if kv_cache is not None: + if self.is_kv_producer: + attn_metadata.reshape_cache_event = torch.npu.Event() + torch_npu.npu_scatter_nd_update_( + kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view(-1, 1), k.view(-1, k.shape[-1]) + ) # b, s, n, d + if self.is_kv_producer: + attn_metadata.reshape_cache_event.record() + + weights, _ = self.weights_proj(x) + weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(weights, need_gather_q_kv) + + key = kv_cache[2] + assert attn_metadata.sfa_cp_metadata is not None + key = self.gather_kv_cross_cp(key, attn_metadata.sfa_cp_metadata.valid_block_ids) + block_table = attn_metadata.sfa_cp_metadata.block_table_cp + + if self.pcp_size == 1: + return self._execute_indexer_select( + q, key, weights, actual_seq_lengths_query, actual_seq_lengths_key, block_table + ) + + # decode compute + num_decode_tokens = attn_metadata.num_decode_tokens + num_prefills = 
attn_metadata.num_prefills + decode_topk_indices = None + if num_decode_tokens > 0: + decode_topk_indices = self._execute_indexer_select( + q[:num_decode_tokens], + key, + weights[:num_decode_tokens], + actual_seq_lengths_query[:num_decode_tokens], + actual_seq_lengths_key[:num_decode_tokens], + block_table[:num_decode_tokens], + ) + + # prefill compute + if num_prefills == 0: + return decode_topk_indices + q = q[num_decode_tokens:] + weights = weights[num_decode_tokens:] + actual_seq_lengths_key = actual_seq_lengths_key[num_decode_tokens:] + block_table = block_table[num_decode_tokens:] + # pcp split for head and tail + q_head_idx = attn_metadata.sfa_cp_metadata.q_head_idx + q_tail_idx = attn_metadata.sfa_cp_metadata.q_tail_idx + + # q head compute + q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decode_tokens:] + q_head_topk_indices = self._execute_indexer_select( + q=torch.index_select(q, 0, q_head_idx), + key=key, + weights=torch.index_select(weights, 0, q_head_idx), + actual_seq_lengths_query=attn_metadata.sfa_cp_metadata.prefill_q_cum_seqlens // 2, + actual_seq_lengths_key=q_head_actual_seq_lengths_key, + block_table=block_table, + ) + + # q tail compute + q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decode_tokens:] + q_tail_topk_indices = self._execute_indexer_select( + q=torch.index_select(q, 0, q_tail_idx), + key=key, + weights=torch.index_select(weights, 0, q_tail_idx), + actual_seq_lengths_query=attn_metadata.sfa_cp_metadata.prefill_q_cum_seqlens // 2, + actual_seq_lengths_key=q_tail_actual_seq_lengths_key, + block_table=block_table, + ) + + q_full_idx = attn_metadata.sfa_cp_metadata.q_full_idx + topk_indices = torch.index_select(torch.cat([q_head_topk_indices, q_tail_topk_indices], dim=0), 0, q_full_idx) + if decode_topk_indices is not None: + topk_indices = torch.cat([decode_topk_indices, topk_indices], dim=0) + return topk_indices + + def 
_execute_indexer_select(self, q, key, weights, actual_seq_lengths_query, actual_seq_lengths_key, block_table): + if self.use_torch_npu_lightning_indexer: + topk_indices, _ = torch_npu.npu_lightning_indexer( + query=q, + key=key, + weights=weights, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_key=actual_seq_lengths_key, + block_table=block_table, + layout_query="TND", + layout_key="PA_BSND", + sparse_count=2048, + sparse_mode=3, + ) + else: + topk_indices = torch.ops._C_ascend.npu_lightning_indexer( + query=q, + key=key, + weights=weights, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_key=actual_seq_lengths_key, + block_table=block_table, + layout_query="TND", + layout_key="PA_BSND", + sparse_count=2048, + sparse_mode=3, + ) + return topk_indices + + def exec_kv( + self, + kv_no_split: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + kv_cache: tuple, + slots: torch.Tensor, + attn_metadata: M, + ): + if self.pcp_size == 1: + return super().exec_kv(kv_no_split, cos, sin, kv_cache, slots, attn_metadata) + kv_c, k_pe = kv_no_split.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) # type: ignore[misc] + assert len(kv_cache) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)" + assert attn_metadata.sfa_cp_metadata is not None + kv_c_normed = kv_c_normed.view([kv_c_normed.shape[0], self.num_kv_heads, -1]) + k_pe = k_pe.unsqueeze(1) + k_pe = self.rope_single(k_pe, cos, sin) + kv_c_k_pe = torch.cat([kv_c_normed, k_pe], dim=-1) + kv_c_k_pe = get_pcp_group().all_gather(kv_c_k_pe, 0) + kv_c_k_pe = torch.index_select(kv_c_k_pe, 0, attn_metadata.sfa_cp_metadata.pcp_allgather_restore_idx) + kv_c_normed, k_pe = kv_c_k_pe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + slot_mapping = attn_metadata.slot_mapping + torch_npu._npu_reshape_and_cache( + key=kv_c_normed, value=k_pe, key_cache=kv_cache[0], 
value_cache=kv_cache[1], slot_indices=slot_mapping + ) + return None, None + + def _get_full_kv(self, k, attn_metadata: M): + if self.pcp_size == 1 or self.enable_mlapo: + return k + else: + assert attn_metadata.sfa_cp_metadata is not None + k = get_pcp_group().all_gather(k.contiguous(), 0) + k = torch.index_select(k, 0, attn_metadata.sfa_cp_metadata.pcp_allgather_restore_idx) + return k diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index c3f30e2c..64d21ef2 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -6,7 +6,7 @@ import torch_npu import vllm.envs as envs_vllm from torch import nn from vllm.config import VllmConfig, get_current_vllm_config -from vllm.distributed import get_dcp_group, get_pcp_group, get_tensor_model_parallel_world_size, get_tp_group +from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group from vllm.forward_context import get_forward_context from vllm.logger import logger from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder @@ -19,10 +19,12 @@ from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE, MLAPO_MAX_SUPPORTED_TOKENS from vllm_ascend.attention.utils import ( AscendCommonAttentionMetadata, ascend_chunked_prefill_workspace_size, + enable_cp, maybe_save_kv_layer_to_connector, trans_rope_weight, transdata, @@ -68,6 +70,10 @@ class AscendSFABackend(AttentionBackend): @staticmethod def get_builder_cls(): + if enable_cp(): + from vllm_ascend.attention.context_parallel.sfa_cp import AscendSFACPMetadataBuilder + + return AscendSFACPMetadataBuilder return AscendSFAMetadataBuilder @staticmethod @@ -76,6 +82,10 @@ 
class AscendSFABackend(AttentionBackend): @staticmethod def get_impl_cls() -> type["AscendSFAImpl"]: + if enable_cp(): + from vllm_ascend.attention.context_parallel.sfa_cp import AscendSFACPImpl + + return AscendSFACPImpl return AscendSFAImpl @staticmethod @@ -95,12 +105,6 @@ class DSACPContext: actual_seq_lengths_key: torch.Tensor -@dataclass -class SFACPMetadata: - block_table_cp: torch.Tensor - valid_block_ids: torch.Tensor - - @dataclass class AscendSFAMetadata: """Metadata for MLACommon. @@ -133,7 +137,10 @@ class AscendSFAMetadata: attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill dsa_cp_context: DSACPContext | None = None reshape_cache_event: torch.npu.Event = None - sfa_cp_metadata: SFACPMetadata | None = None + sfa_cp_metadata: AscendPCPMetadata | None = None + num_decodes: int = 0 + num_decode_tokens: int = 0 + num_prefills: int = 0 M = TypeVar("M", bound=AscendSFAMetadata) @@ -185,14 +192,6 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device) self.actual_seq_lengths_key = torch.empty_like(self.actual_seq_lengths_query) - self.pcp_size = get_pcp_group().world_size - self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0 - self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None - - self.dcp_size = get_dcp_group().world_size - self.dcp_rank = get_dcp_group().rank_in_group if self.dcp_size > 1 else 0 - self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None - @staticmethod def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int: return ascend_chunked_prefill_workspace_size(vllm_config) @@ -309,22 +308,6 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): actual_seq_lengths_key=actual_seq_lengths_key, ) - sfa_cp_metadata = None - if self.pcp_size * self.dcp_size > 1: - valid_block_ids, new_block_table = 
block_table.flatten().unique(return_inverse=True) - num_blocks = valid_block_ids.shape[0] - # Note(qcs): `block_table_cp` will have dirty values in the part beyond kv_lens. - # We assume that we can always get the correct kv_lens or kv index, - # so we omit the dirty value processing here. - block_table_cp = ( - new_block_table.unsqueeze(-1).to(block_table) - + (torch.arange(self.pcp_size * self.dcp_size) * num_blocks).view(1, 1, -1).to(block_table) - ).reshape(block_table.shape[0], -1) - sfa_cp_metadata = SFACPMetadata( - block_table_cp=block_table_cp, - valid_block_ids=valid_block_ids, - ) - return self.metadata_cls( # type: ignore num_input_tokens=common_attn_metadata.num_input_tokens, num_actual_tokens=num_actual_tokens, @@ -338,7 +321,6 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): sin=sin[:num_input_tokens], cos=cos[:num_input_tokens], dsa_cp_context=dsa_cp_context, - sfa_cp_metadata=sfa_cp_metadata, ) def build_for_graph_capture( @@ -453,14 +435,6 @@ class AscendSFAImpl(MLAAttentionImpl): ) register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs) - self.pcp_size = get_pcp_group().world_size - self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0 - self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None - - self.dcp_size = get_dcp_group().world_size - self.dcp_rank = get_dcp_group().rank_in_group if self.dcp_size > 1 else 0 - self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None - def process_weights_after_loading(self, act_dtype: torch.dtype): # NOTE: We currently do not support quant kv_b_proj. 
assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod) @@ -562,6 +536,9 @@ class AscendSFAImpl(MLAAttentionImpl): # Convert from (N, B, L) to (B, N, L) return ql_nope.transpose(0, 1), q_pe + def _get_full_kv(self, k, attn_metadata): + return k + def exec_kv( self, kv_no_split: torch.Tensor, @@ -569,6 +546,7 @@ class AscendSFAImpl(MLAAttentionImpl): sin: torch.Tensor, kv_cache: tuple, slots: torch.Tensor, + attn_metadata: M, ): B = kv_no_split.shape[0] N = self.num_kv_heads @@ -835,7 +813,7 @@ class AscendSFAImpl(MLAAttentionImpl): actual_seq_lengths_query = attn_metadata.dsa_cp_context.actual_seq_lengths_query actual_seq_lengths_key = attn_metadata.dsa_cp_context.actual_seq_lengths_key - k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping) + k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping, attn_metadata) if self.enable_dsa_cp: assert k_pe is not None @@ -875,6 +853,7 @@ class AscendSFAImpl(MLAAttentionImpl): torch_npu.npu_scatter_nd_update_(kv_cache[0].view(-1, k_nope.shape[-1]), slot_mapping, k_nope) torch_npu.npu_scatter_nd_update_(kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping, k_pe) + k = self._get_full_kv(k, attn_metadata) if kv_cache is not None: torch_npu.npu_scatter_nd_update_( kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view(-1, 1), k.view(-1, k.shape[-1]) @@ -894,31 +873,8 @@ class AscendSFAImpl(MLAAttentionImpl): need_gather_q_kv=need_gather_q_kv, ) - block_table = attn_metadata.block_table - kv = kv_cache[0] - key_rope = kv_cache[1] - if self.pcp_size * self.dcp_size > 1: - assert attn_metadata.sfa_cp_metadata is not None - valid_block_ids = attn_metadata.sfa_cp_metadata.valid_block_ids - kv = self.gather_kv_cross_cp(kv, valid_block_ids) - key_rope = self.gather_kv_cross_cp(key_rope, valid_block_ids) - block_table = attn_metadata.sfa_cp_metadata.block_table_cp - - attn_output = torch.ops._C_ascend.npu_sparse_flash_attention( - query=ql_nope, - key=kv, - value=kv, - 
sparse_indices=topk_indices, - scale_value=self.scale, - sparse_block_size=1, - block_table=block_table, - actual_seq_lengths_query=actual_seq_lengths_query, - actual_seq_lengths_kv=actual_seq_lengths_key, - query_rope=q_pe, - key_rope=key_rope, - layout_query="TND", - layout_kv="PA_BSND", - sparse_mode=3, + attn_output = self._execute_sparse_flash_attention_process( + ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key ) attn_output = self._v_up_proj(attn_output) @@ -950,14 +906,30 @@ class AscendSFAImpl(MLAAttentionImpl): return output_padded - def gather_kv_cross_cp(self, kv_cache: torch.Tensor, valid_block_ids: torch.Tensor) -> torch.Tensor: - # Note(qcs): we need set kv_cache_interleave_size = block_size for sfa!!! - kv_cache = torch.index_select(kv_cache, 0, valid_block_ids) - if self.dcp_size > 1: - kv_cache = get_dcp_group().all_gather(kv_cache, 0) - if self.pcp_size > 1: - kv_cache = get_pcp_group().all_gather(kv_cache, 0) - return kv_cache + def _execute_sparse_flash_attention_process( + self, ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key + ): + block_table = attn_metadata.block_table + kv = kv_cache[0] + key_rope = kv_cache[1] + + attn_output = torch.ops._C_ascend.npu_sparse_flash_attention( + query=ql_nope, + key=kv, + value=kv, + sparse_indices=topk_indices, + scale_value=self.scale, + sparse_block_size=1, + block_table=block_table, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_kv=actual_seq_lengths_key, + query_rope=q_pe, + key_rope=key_rope, + layout_query="TND", + layout_kv="PA_BSND", + sparse_mode=3, + ) + return attn_output def indexer_select_pre_process( self, @@ -1038,10 +1010,6 @@ class AscendSFAImpl(MLAAttentionImpl): key = kv_cache[2] block_table = attn_metadata.block_table - if self.pcp_size * self.dcp_size > 1: - assert attn_metadata.sfa_cp_metadata is not None - key = self.gather_kv_cross_cp(key, 
attn_metadata.sfa_cp_metadata.valid_block_ids) - block_table = attn_metadata.sfa_cp_metadata.block_table_cp # DSV3.2 currently has graph compilation issues when using torch_npu.npu.lightning_indexer. # So two branches are maintained temporarily.