[Perf] Add FIA interface in FA case (#3321)
### What this PR does / why we need it?
Add the new `npu_fused_infer_attention_score` op to improve performance in the flash attention case.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: ZYang6263 <zy626375@gmail.com>
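For context, here is a minimal CPU-runnable sketch (not the PR's code; the helper name `use_fused_infer_attention` is made up) of the gating condition the diff below applies: the fused-infer-attention (FIA) kernel is used on CANN 8.3 builds when the head size is not 256, while everything else keeps the legacy `_npu_flash_attention` path.

```python
import torch


def use_fused_infer_attention(head_size: int) -> bool:
    """Sketch of the FIA gating check: CANN 8.3 builds with head_size != 256."""
    # torch.version.cann is only present on torch_npu (Ascend) builds, so use
    # getattr to stay runnable on CPU-only installs.
    cann_version = getattr(torch.version, "cann", None)
    return bool(cann_version) and cann_version.startswith("8.3") and head_size != 256


print(use_fused_infer_attention(head_size=128))  # False on a CPU-only build
```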
@@ -348,15 +348,50 @@ class AscendAttentionBackendImpl(AttentionImpl):
             mask = torch_npu.npu_format_cast(mask.contiguous(),
                                              ACL_FORMAT_FRACTAL_NZ)

-            torch_npu._npu_flash_attention(query=query,
-                                           key=key,
-                                           value=value,
-                                           mask=mask,
-                                           seq_len=attn_metadata.seq_lens,
-                                           scale_value=self.scale,
-                                           num_heads=self.num_heads,
-                                           num_kv_heads=self.num_kv_heads,
-                                           out=output)
+            num_tokens = query.shape[0]
+            if torch.version.cann.startswith("8.3") and self.head_size != 256:
+                query_start_loc = attn_metadata.actual_seq_lengths_q
+                num_tokens = query_start_loc[-1]
+                softmax_lse = torch.empty(num_tokens,
+                                          dtype=query.dtype,
+                                          device=query.device)
+                workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                    query=query,
+                    key=key,
+                    value=value,
+                    atten_mask=attn_metadata.attn_mask,
+                    input_layout="TND",
+                    actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+                    actual_seq_lengths_kv=attn_metadata.actual_seq_lengths_q,
+                    num_key_value_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale=self.scale,
+                    sparse_mode=3)
+                torch_npu.npu_fused_infer_attention_score.out(
+                    query=query,
+                    key=key,
+                    value=value,
+                    atten_mask=attn_metadata.attn_mask,
+                    input_layout="TND",
+                    actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+                    actual_seq_lengths_kv=attn_metadata.actual_seq_lengths_q,
+                    num_key_value_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale=self.scale,
+                    sparse_mode=3,
+                    workspace=workspace,
+                    out=[output, softmax_lse])
+
+            else:
+                torch_npu._npu_flash_attention(query=query,
+                                               key=key,
+                                               value=value,
+                                               mask=mask,
+                                               seq_len=attn_metadata.seq_lens,
+                                               scale_value=self.scale,
+                                               num_heads=self.num_heads,
+                                               num_kv_heads=self.num_kv_heads,
+                                               out=output)
             assert output is not None
             return output[:num_tokens, :, :]
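The FIA branch sizes `softmax_lse` and the final output slice from the last entry of `actual_seq_lengths_q`, which the diff treats as cumulative query lengths (`query_start_loc`). A small CPU-side illustration, with made-up shapes, of why the kernel output is trimmed to `num_tokens`:

```python
import torch

# Cumulative query lengths of three requests; the last entry is the true token count.
actual_seq_lengths_q = torch.tensor([3, 7, 12])
padded_tokens, num_heads, head_size = 16, 8, 128  # query padded beyond 12 tokens

output = torch.randn(padded_tokens, num_heads, head_size)
num_tokens = int(actual_seq_lengths_q[-1])  # 12 real tokens

trimmed = output[:num_tokens, :, :]
print(trimmed.shape)  # torch.Size([12, 8, 128])
```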
@@ -898,9 +898,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):

         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
-            max_seq_len = max(seq_lens.max().item(), 0)
-            return self.attn_mask_builder.get_attn_mask(
-                max_seq_len, self.dtype, self.device)
+            if torch.version.cann.startswith("8.3"):
+                return self.attn_mask_builder.get_splitfuse_attn_mask()
+            else:
+                max_seq_len = max(seq_lens, default=0)
+                return self.attn_mask_builder.get_attn_mask(
+                    max_seq_len, self.dtype, self.device)
         # Prefill with cache hit.
         elif attn_state == AscendAttentionState.PrefillCacheHit:
             return self.attn_mask_builder.get_attn_mask(
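In the non-CANN-8.3 branch of the second hunk, the mask length now comes from the builtin `max(seq_lens, default=0)` rather than `seq_lens.max().item()`, which suggests `seq_lens` is handled as a plain Python sequence and that an empty batch no longer raises. A quick illustration with hypothetical values:

```python
seq_lens = [5, 9, 2]
print(max(seq_lens, default=0))  # 9 -> max prefill length used to build the mask
print(max([], default=0))        # 0 -> empty input is tolerated instead of raising
```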