Remove chunked_prefill_for_mla and fix ring_mla bug (#2781)
### What this PR does / why we need it?
Remove the chunked-prefill-for-MLA branch in MLA attention, and change the dtype of
prefill_mask to avoid an accuracy problem.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- vLLM version: v0.10.2
- vLLM main:
ef7eefe17a
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
This commit is contained in:
@@ -20,7 +20,6 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
|
||||
from vllm_ascend.utils import npu_prefetch
|
||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||
|
||||
@@ -491,7 +490,6 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
self.enable_prefetch = ascend_config.enable_prefetch
|
||||
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
||||
self.chunked_prefill_for_mla = ascend_config.chunked_prefill_for_mla
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.ring_mla_mask_size = 512
|
||||
@@ -673,84 +671,47 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.v_head_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
query = torch.cat((q_nope, q_pe), dim=-1)
|
||||
key = torch.cat((k_nope, k_pe), dim=-1)
|
||||
torch_npu._npu_flash_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=attn_metadata.attn_mask,
|
||||
seq_len=attn_metadata.prefill.context_lens,
|
||||
scale_value=self.scale,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_heads,
|
||||
out=attn_output)
|
||||
elif self.chunked_prefill_for_mla:
|
||||
attn_lse = torch.empty(self.num_heads,
|
||||
num_tokens,
|
||||
dtype=torch.float32,
|
||||
device=q_nope.device)
|
||||
if self.prefill_mask is None:
|
||||
self.prefill_mask = torch.triu(
|
||||
torch.ones(self.ring_mla_mask_size,
|
||||
self.ring_mla_mask_size,
|
||||
device=q_nope.device,
|
||||
dtype=q_nope.dtype), 1)
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=value,
|
||||
mask=self.prefill_mask,
|
||||
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
|
||||
dtype=torch.int32),
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=None,
|
||||
prev_lse=None,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="mask_type_triu",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_first_ring",
|
||||
output=attn_output,
|
||||
softmax_lse=attn_lse)
|
||||
attn_output, attn_lse = self._compute_prefill_context( \
|
||||
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
||||
else:
|
||||
query = torch.cat((q_nope, q_pe), dim=-1)
|
||||
attn_output_torch = torch.empty(num_tokens,
|
||||
self.num_heads * self.v_head_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
# current requests is chunked in prefill, disable flash attention with chunked prefill
|
||||
vanilla_chunked_prefill_mla(
|
||||
output=attn_output_torch,
|
||||
query=query,
|
||||
kv_cache=kv_c_and_k_pe_cache,
|
||||
block_tables=attn_metadata.prefill.block_table,
|
||||
query_lens=attn_metadata.prefill.query_lens,
|
||||
context_lens=attn_metadata.prefill.context_lens,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
max_query_len=attn_metadata.prefill.max_query_len,
|
||||
max_context_len=attn_metadata.prefill.max_seq_lens,
|
||||
nope_dim=self.qk_nope_head_dim,
|
||||
rope_dim=self.qk_rope_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
scale=self.scale,
|
||||
alibi_slopes=None,
|
||||
causal=True)
|
||||
attn_lse = torch.empty(self.num_heads,
|
||||
num_tokens,
|
||||
dtype=torch.float32,
|
||||
device=q_nope.device)
|
||||
if self.prefill_mask is None:
|
||||
if q_nope.dtype == torch.float16:
|
||||
mask_value = torch.finfo(torch.float32).min
|
||||
else:
|
||||
mask_value = 1
|
||||
prefill_mask = torch.triu(
|
||||
torch.ones(self.ring_mla_mask_size,
|
||||
self.ring_mla_mask_size,
|
||||
device=q_nope.device,
|
||||
dtype=q_nope.dtype), 1)
|
||||
self.prefill_mask = torch.where(prefill_mask == 1, mask_value,
|
||||
0).to(q_nope.dtype)
|
||||
torch_npu.atb.npu_ring_mla(q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=value,
|
||||
mask=self.prefill_mask,
|
||||
seqlen=torch.tensor(
|
||||
attn_metadata.prefill.query_lens,
|
||||
dtype=torch.int32),
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=None,
|
||||
prev_lse=None,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="mask_type_triu",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_first_ring",
|
||||
output=attn_output,
|
||||
softmax_lse=attn_lse)
|
||||
attn_output, attn_lse = self._compute_prefill_context( \
|
||||
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
||||
|
||||
attn_output = attn_output.reshape(
|
||||
[num_tokens, self.num_heads * self.v_head_dim])
|
||||
if attn_metadata.attn_state in [
|
||||
AscendAttentionState.ChunkedPrefill,
|
||||
AscendAttentionState.SpecDecoding,
|
||||
AscendAttentionState.PrefillCacheHit
|
||||
] and not self.chunked_prefill_for_mla:
|
||||
attn_output = attn_output_torch
|
||||
return attn_output
|
||||
|
||||
def exec_kv_decode(
|
||||
|
||||
Reference in New Issue
Block a user