diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index ff0240bb..b524e648 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -183,7 +183,6 @@ class AscendMetadataForDecode:
 class AscendMetadata:
     # **************************** Basic Properties ************************** #
     attn_mask: Optional[torch.Tensor] = None
-    fia_attn_mask: Optional[torch.Tensor] = None
 
     # Current state of this attention run.
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
@@ -312,21 +311,18 @@ class AscendAttentionMetadataBuilder:
                                                      num_actual_tokens_pcp_padded]
         # slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
         attn_mask = common_attn_metadata.attn_mask
-        fia_attn_mask = common_attn_metadata.fia_attn_mask
         attn_state = common_attn_metadata.attn_state
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
                                                                        num_reqs + 1]
 
         num_computed_tokens_cpu = (seq_lens - query_lens)
 
-        if attn_state == AscendAttentionState.DecodeOnly and \
-                common_attn_metadata.num_input_tokens > num_actual_tokens:
+        if common_attn_metadata.num_input_tokens > num_actual_tokens:
             padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
             seq_lens = torch.cat([
                 seq_lens,
-                torch.ones(padded_num_tokens,
-                           dtype=seq_lens.dtype,
-                           device=seq_lens.device)
+                torch.tensor([padded_num_tokens
+                              ]).to(seq_lens.device).to(seq_lens.dtype)
             ])
             block_table_padding = torch.zeros(
                 (padded_num_tokens, ) + block_table.shape[1:],
@@ -335,10 +331,8 @@ class AscendAttentionMetadataBuilder:
             block_table = torch.cat([block_table, block_table_padding], dim=0)
             query_start_loc_cpu = torch.cat([
                 query_start_loc_cpu,
-                torch.arange(query_start_loc_cpu[-1] + 1,
-                             query_start_loc_cpu[-1] + padded_num_tokens,
-                             dtype=query_start_loc_cpu.dtype,
-                             device=query_start_loc_cpu.device)
+                torch.tensor([query_start_loc_cpu[-1] + padded_num_tokens]).to(
+                    query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
             ])
 
         query_start_loc = query_start_loc_cpu.to(self.device,
@@ -471,7 +465,6 @@ class AscendAttentionMetadataBuilder:
             actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
-            fia_attn_mask=fia_attn_mask,
             attn_state=attn_state,
             num_prefills=num_prefills,
             num_decodes=num_decodes,
@@ -604,7 +597,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
             actual_seq_lengths_kv = attn_metadata.seq_lens_list
             num_tokens = attn_metadata.query_start_loc_list[-1]
 
-            query = query[:num_tokens]
             graph_params = get_graph_params()
             query_start_loc = attn_metadata.query_start_loc_list
             # Prepare tensors for attention output
@@ -618,7 +610,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                     query=query,
                     key=key,
                     value=value,
-                    atten_mask=attn_metadata.fia_attn_mask,
+                    atten_mask=attn_metadata.attn_mask,
                     block_table=block_table,
                     input_layout="TND",
                     block_size=block_size,
@@ -641,7 +633,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 graph_params.attn_params[num_tokens].append(
                     (weak_ref_tensors(query), weak_ref_tensors(key),
                      weak_ref_tensors(value), weak_ref_tensors(block_table),
-                     weak_ref_tensors(attn_metadata.fia_attn_mask), block_size,
+                     weak_ref_tensors(attn_metadata.attn_mask), block_size,
                      actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
                      self.num_heads, self.scale, weak_ref_tensors(output),
                      weak_ref_tensors(softmax_lse)))
@@ -651,7 +643,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                     query=query,
                     key=key,
                     value=value,
-                    atten_mask=attn_metadata.fia_attn_mask,
+                    atten_mask=attn_metadata.attn_mask,
                     block_table=block_table,
                     input_layout="TND",
                     block_size=block_size,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 1f46b9d4..9e28e117 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -321,7 +321,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.attn_groups: list[list[AttentionGroup]] = []
         self.encoder_cache: Dict[str, torch.Tensor] = {}
         self.attn_mask = None
-        self.fia_attn_mask = None
         self.attn_state = None
         self.requests: Dict[str, CachedRequestState] = {}
         self.intermediate_tensors: Optional[IntermediateTensors] = None
@@ -982,23 +981,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # Pooling situation.
         if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
             return self.attn_mask_builder.get_pooling_mask(self.device)
-        # fia prefill situation.
-        if attn_state in [
-                AscendAttentionState.PrefillNoCache,
-                AscendAttentionState.PrefillCacheHit,
-                AscendAttentionState.ChunkedPrefill
-        ]:
-            return self.attn_mask_builder.get_splitfuse_attn_mask()
-
-        # Decode-only situation.
-        return None
-
-    def _make_fia_attention_mask(self) -> torch.Tensor:
-        # pcp situation.
-        if self.pcp_size > 1:
-            return None
-        if self.attn_mask_builder is None:
-            raise ValueError("Attn mask builder is None")
         return self.attn_mask_builder.get_splitfuse_attn_mask()
 
     def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
@@ -1579,7 +1561,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
                                                    position=positions_cpu,
                                                    attn_state=attn_state)
-        self.fia_attn_mask = self._make_fia_attention_mask()
         self.attn_state = attn_state  # type: ignore
 
         self.with_prefill = with_prefill
@@ -1804,7 +1785,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             num_computed_tokens_cpu=num_computed_tokens_cpu,
             positions=self.positions,
             attn_mask=self.attn_mask,
-            fia_attn_mask=self.fia_attn_mask,
             spec_attn_mask=self.spec_attn_mask,
             attn_state=self.attn_state,
             is_only_prefill=bool(np.all(num_valid_tokens != 1)),
@@ -2723,10 +2703,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.query_lens = torch.from_numpy(num_scheduled_tokens)
 
         assigned_mask_dim = 2048
-        self.fia_attn_mask = torch.triu(torch.ones(assigned_mask_dim,
-                                                    assigned_mask_dim),
-                                        diagonal=1).to(torch.int8).to(
-                                            self.device)
+        self.attn_mask = torch.triu(torch.ones(assigned_mask_dim,
+                                                assigned_mask_dim),
+                                    diagonal=1).to(torch.int8).to(
+                                        self.device)
 
         num_computed_tokens_cpu = (
             self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
@@ -2770,7 +2750,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             num_computed_tokens_cpu=num_computed_tokens_cpu,
             positions=self.positions,
             attn_mask=self.attn_mask,
-            fia_attn_mask=self.fia_attn_mask,
            spec_attn_mask=self.spec_attn_mask,
             attn_state=self.attn_state,
             max_query_len=max_query_len,