Revert "[Perf] Add FIA interface in FA case" (#3553)
Reverts vllm-project/vllm-ascend#3321, which introduced an output dimension mismatch and an accuracy issue.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: ZYang6263 <zy626375@gmail.com>
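For background on the dimension mismatch: the removed FIA path re-derived num_tokens from attn_metadata.actual_seq_lengths_q[-1] before slicing the returned output, whereas the restored flash-attention path keeps num_tokens = query.shape[0]. Below is a minimal, hedged sketch in plain PyTorch (no NPU required; the shapes and sequence lengths are illustrative assumptions, not values from the failing workload) showing how the two token counts diverge once the query tensor carries padding, which changes the shape of the sliced output. Whether padding is the actual trigger is not claimed here; the revert simply restores the pre-#3321 behavior shown in the diff below.

    import torch

    # Hypothetical sizes; the real values come from the NPU model runner.
    num_heads, head_size = 8, 128
    # Cumulative query lengths per request (TND layout); the last entry is
    # the number of real tokens.
    actual_seq_lengths_q = torch.tensor([5, 9, 12])
    padded_num_tokens = 16  # query padded up to a scheduling bucket (assumption)

    query = torch.randn(padded_num_tokens, num_heads, head_size)
    output = torch.empty_like(query)

    # Restored path: every row of the (possibly padded) query counts.
    num_tokens_fa = query.shape[0]                  # 16
    # Removed FIA path: only tokens covered by the cumulative lengths count.
    num_tokens_fia = int(actual_seq_lengths_q[-1])  # 12

    print(output[:num_tokens_fa].shape)   # torch.Size([16, 8, 128])
    print(output[:num_tokens_fia].shape)  # torch.Size([12, 8, 128])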
@@ -348,50 +348,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 mask = torch_npu.npu_format_cast(mask.contiguous(),
                                                  ACL_FORMAT_FRACTAL_NZ)
 
-            num_tokens = query.shape[0]
-            if torch.version.cann.startswith("8.3") and self.head_size != 256:
-                query_start_loc = attn_metadata.actual_seq_lengths_q
-                num_tokens = query_start_loc[-1]
-                softmax_lse = torch.empty(num_tokens,
-                                          dtype=query.dtype,
-                                          device=query.device)
-                workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                    query=query,
-                    key=key,
-                    value=value,
-                    atten_mask=attn_metadata.attn_mask,
-                    input_layout="TND",
-                    actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
-                    actual_seq_lengths_kv=attn_metadata.actual_seq_lengths_q,
-                    num_key_value_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale=self.scale,
-                    sparse_mode=3)
-                torch_npu.npu_fused_infer_attention_score.out(
-                    query=query,
-                    key=key,
-                    value=value,
-                    atten_mask=attn_metadata.attn_mask,
-                    input_layout="TND",
-                    actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
-                    actual_seq_lengths_kv=attn_metadata.actual_seq_lengths_q,
-                    num_key_value_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale=self.scale,
-                    sparse_mode=3,
-                    workspace=workspace,
-                    out=[output, softmax_lse])
-            else:
-                torch_npu._npu_flash_attention(query=query,
-                                               key=key,
-                                               value=value,
-                                               mask=mask,
-                                               seq_len=attn_metadata.seq_lens,
-                                               scale_value=self.scale,
-                                               num_heads=self.num_heads,
-                                               num_kv_heads=self.num_kv_heads,
-                                               out=output)
+            torch_npu._npu_flash_attention(query=query,
+                                           key=key,
+                                           value=value,
+                                           mask=mask,
+                                           seq_len=attn_metadata.seq_lens,
+                                           scale_value=self.scale,
+                                           num_heads=self.num_heads,
+                                           num_kv_heads=self.num_kv_heads,
+                                           out=output)
         assert output is not None
         return output[:num_tokens, :, :]
@@ -898,12 +898,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
-            if torch.version.cann.startswith("8.3"):
-                return self.attn_mask_builder.get_splitfuse_attn_mask()
-            else:
-                max_seq_len = max(seq_lens, default=0)
-                return self.attn_mask_builder.get_attn_mask(
-                    max_seq_len, self.dtype, self.device)
+            max_seq_len = max(seq_lens.max().item(), 0)
+            return self.attn_mask_builder.get_attn_mask(
+                max_seq_len, self.dtype, self.device)
         # Prefill with cache hit.
         elif attn_state == AscendAttentionState.PrefillCacheHit:
             return self.attn_mask_builder.get_attn_mask(