diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index fce9e8c..156db60 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -348,15 +348,50 @@ class AscendAttentionBackendImpl(AttentionImpl):
             mask = torch_npu.npu_format_cast(mask.contiguous(),
                                              ACL_FORMAT_FRACTAL_NZ)
-        torch_npu._npu_flash_attention(query=query,
-                                       key=key,
-                                       value=value,
-                                       mask=mask,
-                                       seq_len=attn_metadata.seq_lens,
-                                       scale_value=self.scale,
-                                       num_heads=self.num_heads,
-                                       num_kv_heads=self.num_kv_heads,
-                                       out=output)
+        num_tokens = query.shape[0]
+        if torch.version.cann.startswith("8.3") and self.head_size != 256:
+            query_start_loc = attn_metadata.actual_seq_lengths_q
+            num_tokens = query_start_loc[-1]
+            softmax_lse = torch.empty(num_tokens,
+                                      dtype=query.dtype,
+                                      device=query.device)
+            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                query=query,
+                key=key,
+                value=value,
+                atten_mask=attn_metadata.attn_mask,
+                input_layout="TND",
+                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+                actual_seq_lengths_kv=attn_metadata.actual_seq_lengths_q,
+                num_key_value_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale=self.scale,
+                sparse_mode=3)
+            torch_npu.npu_fused_infer_attention_score.out(
+                query=query,
+                key=key,
+                value=value,
+                atten_mask=attn_metadata.attn_mask,
+                input_layout="TND",
+                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+                actual_seq_lengths_kv=attn_metadata.actual_seq_lengths_q,
+                num_key_value_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale=self.scale,
+                sparse_mode=3,
+                workspace=workspace,
+                out=[output, softmax_lse])
+
+        else:
+            torch_npu._npu_flash_attention(query=query,
+                                           key=key,
+                                           value=value,
+                                           mask=mask,
+                                           seq_len=attn_metadata.seq_lens,
+                                           scale_value=self.scale,
+                                           num_heads=self.num_heads,
+                                           num_kv_heads=self.num_kv_heads,
+                                           out=output)
         assert output is not None
         return output[:num_tokens, :, :]
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 06c52bf..dc21cd1 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -898,9 +898,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
-            max_seq_len = max(seq_lens.max().item(), 0)
-            return self.attn_mask_builder.get_attn_mask(
-                max_seq_len, self.dtype, self.device)
+            if torch.version.cann.startswith("8.3"):
+                return self.attn_mask_builder.get_splitfuse_attn_mask()
+            else:
+                max_seq_len = max(seq_lens, default=0)
+                return self.attn_mask_builder.get_attn_mask(
+                    max_seq_len, self.dtype, self.device)
         # Prefill with cache hit.
         elif attn_state == AscendAttentionState.PrefillCacheHit:
             return self.attn_mask_builder.get_attn_mask(
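
Note on the version gate shared by both hunks (not part of the patch): the patch reads the CANN version from torch.version.cann, which torch_npu populates on Ascend builds, and calls .startswith("8.3") on it directly. That probe raises if the attribute is absent or unset (e.g. a torch build without torch_npu). A minimal defensive sketch of the same check; the helper name cann_version_is is hypothetical and not part of the patch:

    import torch

    def cann_version_is(prefix: str) -> bool:
        # torch_npu publishes the CANN runtime version on torch.version;
        # getattr keeps the probe safe when the attribute is missing or None.
        cann = getattr(torch.version, "cann", None)
        return isinstance(cann, str) and cann.startswith(prefix)

    # Usage mirroring the patch's gating (sketch, not the patch itself):
    #   if cann_version_is("8.3") and self.head_size != 256:
    #       ...  # fused npu_fused_infer_attention_score path
    #   else:
    #       ...  # fallback _npu_flash_attention path

Since this is a prefix match, any 8.3.x point release would also take the fused path, which appears to be the intent of the gate.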