[Perf] Add FIA interface in FA case (#3321)

### What this PR does / why we need it?

Add the new `npu_fused_infer_attention_score` op to improve performance in the flash attention case.
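
For background, `npu_fused_infer_attention_score` computes prefill attention in a single fused kernel and additionally returns the softmax log-sum-exp. Below is a minimal sketch of a direct call, assuming an Ascend NPU with CANN 8.3 and `torch_npu` installed; the shapes, lengths, and the compressed-mask convention for `sparse_mode=3` are illustrative assumptions, not taken from this PR:

```python
import torch
import torch_npu  # assumed available; provides the fused attention op

num_heads, num_kv_heads, head_size = 8, 8, 128
cu_seq_lens = [16, 48]          # cumulative end offsets of two packed sequences
total_tokens = cu_seq_lens[-1]  # 48 tokens overall

# TND layout: (tokens, heads, head_dim), sequences packed back to back.
q = torch.randn(total_tokens, num_heads, head_size,
                dtype=torch.float16, device="npu")
k = torch.randn(total_tokens, num_kv_heads, head_size,
                dtype=torch.float16, device="npu")
v = torch.randn_like(k)

# Assumption: sparse_mode=3 takes a compressed 2048x2048 causal mask,
# with True marking positions to be masked out.
causal_mask = torch.triu(
    torch.ones(2048, 2048, dtype=torch.bool, device="npu"), diagonal=1)

attn_out, softmax_lse = torch_npu.npu_fused_infer_attention_score(
    q, k, v,
    atten_mask=causal_mask,
    input_layout="TND",
    actual_seq_lengths=cu_seq_lens,
    actual_seq_lengths_kv=cu_seq_lens,
    num_heads=num_heads,
    num_key_value_heads=num_kv_heads,
    scale=head_size ** -0.5,
    sparse_mode=3)
```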

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: ZYang6263 <zy626375@gmail.com>
Author: ZYang6263
Date: 2025-10-19 12:45:33 +08:00 (committed by GitHub)
Parent: 4b3bd4f397
Commit: 1e78ecbad6
2 changed files with 50 additions and 12 deletions


```diff
@@ -348,15 +348,50 @@ class AscendAttentionBackendImpl(AttentionImpl):
             mask = torch_npu.npu_format_cast(mask.contiguous(),
                                              ACL_FORMAT_FRACTAL_NZ)
-            torch_npu._npu_flash_attention(query=query,
-                                           key=key,
-                                           value=value,
-                                           mask=mask,
-                                           seq_len=attn_metadata.seq_lens,
-                                           scale_value=self.scale,
-                                           num_heads=self.num_heads,
-                                           num_kv_heads=self.num_kv_heads,
-                                           out=output)
             num_tokens = query.shape[0]
+            if torch.version.cann.startswith("8.3") and self.head_size != 256:
+                query_start_loc = attn_metadata.actual_seq_lengths_q
+                num_tokens = query_start_loc[-1]
+                softmax_lse = torch.empty(num_tokens,
+                                          dtype=query.dtype,
+                                          device=query.device)
+                workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                    query=query,
+                    key=key,
+                    value=value,
+                    atten_mask=attn_metadata.attn_mask,
+                    input_layout="TND",
+                    actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+                    actual_seq_lengths_kv=attn_metadata.actual_seq_lengths_q,
+                    num_key_value_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale=self.scale,
+                    sparse_mode=3)
+                torch_npu.npu_fused_infer_attention_score.out(
+                    query=query,
+                    key=key,
+                    value=value,
+                    atten_mask=attn_metadata.attn_mask,
+                    input_layout="TND",
+                    actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+                    actual_seq_lengths_kv=attn_metadata.actual_seq_lengths_q,
+                    num_key_value_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale=self.scale,
+                    sparse_mode=3,
+                    workspace=workspace,
+                    out=[output, softmax_lse])
+            else:
+                torch_npu._npu_flash_attention(query=query,
+                                               key=key,
+                                               value=value,
+                                               mask=mask,
+                                               seq_len=attn_metadata.seq_lens,
+                                               scale_value=self.scale,
+                                               num_heads=self.num_heads,
+                                               num_kv_heads=self.num_kv_heads,
+                                               out=output)
             assert output is not None
             return output[:num_tokens, :, :]
```
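
Note that `actual_seq_lengths_q` holds cumulative end offsets rather than per-sequence lengths, so its last entry is the number of valid packed tokens; the final slice then drops any padded rows. A small plain-torch illustration with hypothetical values:

```python
import torch

# Hypothetical per-sequence prefill lengths and the cumulative offsets that
# the fused op expects in actual_seq_lengths / actual_seq_lengths_kv.
seq_lens = torch.tensor([5, 3, 7])
query_start_loc = torch.cumsum(seq_lens, dim=0)  # tensor([ 5,  8, 15])

# The last offset is the count of valid tokens; rows beyond it in a padded
# output buffer are dropped, mirroring output[:num_tokens, :, :] above.
num_tokens = int(query_start_loc[-1])            # 15
output = torch.randn(16, 8, 128)                 # buffer padded to 16 rows
trimmed = output[:num_tokens, :, :]              # shape (15, 8, 128)
```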


```diff
@@ -898,9 +898,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
-            max_seq_len = max(seq_lens.max().item(), 0)
-            return self.attn_mask_builder.get_attn_mask(
-                max_seq_len, self.dtype, self.device)
+            if torch.version.cann.startswith("8.3"):
+                return self.attn_mask_builder.get_splitfuse_attn_mask()
+            else:
+                max_seq_len = max(seq_lens, default=0)
+                return self.attn_mask_builder.get_attn_mask(
+                    max_seq_len, self.dtype, self.device)
         # Prefill with cache hit.
         elif attn_state == AscendAttentionState.PrefillCacheHit:
             return self.attn_mask_builder.get_attn_mask(
```
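
Both hunks key off the same toolkit check; `torch_npu` publishes the CANN toolkit version as the string `torch.version.cann`. A small sketch of the combined gate, using the hypothetical helper name `fia_path_enabled` and the `head_size != 256` restriction from the attention backend above:

```python
import torch
import torch_npu  # noqa: F401  # registers torch.version.cann on import

def fia_path_enabled(head_size: int) -> bool:
    """Sketch: True when the fused-infer-attention path applies."""
    cann = getattr(torch.version, "cann", None)
    return bool(cann) and cann.startswith("8.3") and head_size != 256

print(fia_path_enabled(128))  # True on CANN 8.3.x
print(fia_path_enabled(256))  # False: falls back to _npu_flash_attention
```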