diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 4f316097..8362e195 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -166,10 +166,12 @@ class AscendMetadataForPrefill:
         actual_chunk_seq_lengths: list[int]
         actual_seq_lengths_kv: list[int]
         starts: torch.Tensor
+        chunk_seq_mask_filtered_indices: torch.Tensor
         chunked_req_mask: Optional[list[bool]] = None
         local_context_lens_allranks: Optional[list[list[int]]] = None
         cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
         kv_inverse_idx_for_chunk: Optional[list[int]] = None
+        batch_chunk_seq_mask: Optional[list[bool]] = None

     """ Prefill Specific Metadata for Ascend"""
     pcp_metadata: Optional[AscendPCPMetadata] = None
@@ -401,6 +403,14 @@ class AscendAttentionMetadataBuilder:
                 cp_kv_recover_idx_for_chunk.to(torch.float32)
             ) if cp_kv_recover_idx_for_chunk is not None else None

+            batch_chunk_seq_mask = (
+                local_context_lens_allranks[:, self.pcp_rank,
+                                            self.dcp_rank] == 0)
+            batch_chunk_seq_mask = torch.repeat_interleave(
+                batch_chunk_seq_mask,
+                repeats=(query_lens * self.pcp_size).to(self.device))
+            chunk_seq_mask_filtered_indices = filter_chunked_req_indices(
+                query_lens, chunked_req_mask).to(self.device)
             chunked_context_metadata = \
                 AscendMetadataForPrefill.ChunkedContextMetadata(
                     actual_chunk_seq_lengths=torch.cumsum(query_lens * pcp_size, dim=0),
@@ -409,7 +419,9 @@ class AscendAttentionMetadataBuilder:
                     starts=local_chunk_starts,
                     local_context_lens_allranks=local_context_lens_allranks,
                     cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk,
-                    kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk
+                    kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk,
+                    batch_chunk_seq_mask=batch_chunk_seq_mask,
+                    chunk_seq_mask_filtered_indices=chunk_seq_mask_filtered_indices
                 )
             attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
             head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
@@ -571,10 +583,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
                              query: torch.Tensor,
                              key: torch.Tensor,
                              value: torch.Tensor,
+                             kv_cache: Tuple[torch.Tensor],
                              attn_metadata: AscendMetadata,
                              output: torch.Tensor,
                              num_tokens=0):
-        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+        if self.pcp_size * self.dcp_size > 1:
+            intermediate_output = self._forward_pcp_dcp(
+                query, key, value, kv_cache, attn_metadata, output)
+            return intermediate_output, query.shape[0]
+        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             block_size = 128
             block_table = None
             actual_seq_lengths_kv = attn_metadata.query_start_loc_list
@@ -1276,9 +1293,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 self.pcp_rank * num_tokens:(self.pcp_rank + 1) *
                 num_tokens, :, :]
            assert attn_output_full_chunk.shape == current_attn_output_prefill.shape and attn_lse_full_chunk.shape == current_attn_lse_prefill.shape
-            seq_len = attn_metadata.query_lens.detach().clone()
-            filtered_indices = filter_chunked_req_indices(
-                seq_len, attn_metadata.prefill.chunked_context.chunked_req_mask)
+            filtered_indices = attn_metadata.prefill.chunked_context.chunk_seq_mask_filtered_indices
             attn_output_prefill_filtered = current_attn_output_prefill[
                 filtered_indices, :, :]

@@ -1322,9 +1337,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
            local_chunked_kv_lens_rank = local_chunked_kv_lens[:, self.pcp_rank,
                                                               self.dcp_rank]
+            total_toks = local_chunked_kv_lens_rank.sum()
             key, value = self._load_kv_for_chunk(attn_metadata, kv_cache,
-                                                 local_chunked_kv_lens_rank, query)
+                                                 local_chunked_kv_lens_rank, query,
+                                                 total_toks)
             if self.dcp_size > 1:
                 num_heads = self.num_heads * self.dcp_size
             else:
@@ -1340,7 +1357,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 dtype=torch.float32,
                 device=query.device)

-            if not torch.all(local_chunked_kv_lens_rank == 0).item():
+            if total_toks > 0:
                 prefix_chunk_output, prefix_chunk_lse = torch.ops.npu.npu_fused_infer_attention_score(
                     query,
                     key,
@@ -1358,6 +1375,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
                     actual_seq_lengths_kv,
                     actual_seq_lengths=attn_metadata.prefill.chunked_context.
                     actual_chunk_seq_lengths)
+                batch_chunk_seq_mask = attn_metadata.prefill.chunked_context.batch_chunk_seq_mask
+                out_mask = batch_chunk_seq_mask[:, None, None].expand_as(
+                    prefix_chunk_output)
+                prefix_chunk_output = torch.where(out_mask, 0, prefix_chunk_output)
+                lse_mask = batch_chunk_seq_mask[:, None,
+                                                None].expand_as(prefix_chunk_lse)
+                prefix_chunk_lse = torch.where(lse_mask, -torch.inf,
+                                               prefix_chunk_lse)
                 prefix_output, prefix_lse = self._update_chunk_attn_out_lse(
                     prefix_chunk_output, prefix_chunk_lse)

@@ -1413,14 +1438,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
         return prefix_output, prefix_lse

     def _load_kv_for_chunk(self, attn_metadata, kv_cache,
-                           local_chunked_kv_lens_rank, query):
+                           local_chunked_kv_lens_rank, query, total_toks):
         cache_key = kv_cache[0]
         cache_value = kv_cache[1]
         num_heads = cache_key.size(2)
         head_size = kv_cache[0].size(-1)

-        total_toks = local_chunked_kv_lens_rank.sum()
-
         key = torch.empty(total_toks,
                           num_heads,
                           head_size,
@@ -1579,7 +1602,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 query, attn_metadata, output)
         else:
             intermediate_output, num_tokens = self.full_graph_attention(
-                query, key, value, attn_metadata, output)
+                query, key, value, kv_cache, attn_metadata, output)

         output[:num_tokens] = intermediate_output[:num_tokens]
         return output
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 2b0481c0..817eaa3b 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -294,21 +294,23 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.scheduler_config = vllm_config.scheduler_config
         self.speculative_config = vllm_config.speculative_config
         self.block_size = vllm_config.cache_config.block_size
-        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
-                                           self.block_size)
-        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
-        decode_max_num_seqs = getattr(self.scheduler_config,
-                                      'decode_max_num_seqs', 0)
-        self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
-                                decode_max_num_seqs)
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
+        self.dcp_size = get_dcp_group().world_size
+        self.dcp_rank = get_dcp_group().rank_in_group
         self.pcp_size = get_prefill_context_model_parallel_world_size(
         ) if prefill_context_parallel_enable() else 1
         self.pcp_rank = get_prefill_context_model_parallel_rank(
         ) if self.pcp_size > 1 else 0
-        self.dcp_size = get_dcp_group().world_size
-        self.dcp_rank = get_dcp_group().rank_in_group
+        decode_max_num_seqs = getattr(self.scheduler_config,
+                                      'decode_max_num_seqs', 0)
+        self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
+                                decode_max_num_seqs)
+        if self.pcp_size > 1:
+            self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
+        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
+                                           self.block_size)
+        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
         self.device = device
         if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
             self.prefetch_stream = torch.npu.Stream(device=device)
@@ -1007,10 +1009,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):

     def _make_attention_mask(self, seq_lens, position,
                              attn_state) -> torch.Tensor:
+        # pcp situation.
         if self.pcp_size > 1:
             return None
         if self.attn_mask_builder is None:
             raise ValueError("Attn mask builder is None")
+        # dcp situation.
         if self.dcp_size > 1:
             return self.attn_mask_builder.get_splitfuse_attn_mask()
         # Pooling situation.
@@ -1018,12 +1022,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             return self.attn_mask_builder.get_pooling_mask(self.device)
         # Chunk Prefill situation.
         elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
-            if self.dcp_size > 1:
-                max_seq_len = max(seq_lens.max().item(), 0)
-                return self.attn_mask_builder.get_attn_mask(
-                    max_seq_len, self.dtype, self.device)
-            else:
-                return self.attn_mask_builder.get_splitfuse_attn_mask()
+            return self.attn_mask_builder.get_splitfuse_attn_mask()

         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
@@ -1039,6 +1038,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         return None

     def _make_fia_attention_mask(self) -> torch.Tensor:
+        # pcp situation.
+        if self.pcp_size > 1:
+            return None
         if self.attn_mask_builder is None:
             raise ValueError("Attn mask builder is None")
         return self.attn_mask_builder.get_splitfuse_attn_mask()
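
Note on the masking added around npu_fused_infer_attention_score above: for requests whose local KV chunk on this (pcp_rank, dcp_rank) is empty, batch_chunk_seq_mask forces the partial chunk output to 0 and its log-sum-exp (LSE) to -inf, so the later LSE-weighted merge ignores the empty chunk instead of mixing in an invalid partial result. Below is a minimal, self-contained sketch of such a merge; merge_attn_out_lse and the tensor shapes are illustrative assumptions, not the PR's _update_chunk_attn_out_lse.

import torch


def merge_attn_out_lse(out_a, lse_a, out_b, lse_b):
    # Log-sum-exp weighted combine of two partial attention results.
    # out_*: [tokens, heads, head_size]; lse_*: [tokens, heads, 1].
    lse_max = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse_max)  # exp(-inf - finite) == 0 for a masked chunk
    w_b = torch.exp(lse_b - lse_max)
    denom = w_a + w_b
    merged_out = (out_a * w_a + out_b * w_b) / denom
    merged_lse = lse_max + torch.log(denom)
    return merged_out, merged_lse


tokens, heads, head_size = 4, 2, 8
out_a = torch.randn(tokens, heads, head_size)
lse_a = torch.randn(tokens, heads, 1)
out_b = torch.randn(tokens, heads, head_size)
lse_b = torch.randn(tokens, heads, 1)
# Pretend token 0 has no local KV chunk on this rank: zero output, -inf LSE,
# mirroring the torch.where masking in the patch above.
out_b[0] = 0
lse_b[0] = -torch.inf
merged_out, merged_lse = merge_attn_out_lse(out_a, lse_a, out_b, lse_b)
assert torch.allclose(merged_out[0], out_a[0])  # masked chunk contributes nothing
assert torch.allclose(merged_lse[0], lse_a[0])

Because exp(-inf - lse_max) evaluates to 0, the masked chunk gets zero weight and the merged result for that token equals the other partial result exactly, which is why zeroing the output and pinning the LSE to -inf is sufficient for empty chunks.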