Remove chunked_prefill_for_mla and fix ring_mla bug (#2781)
### What this PR does / why we need it?
Remove the chunked-prefill branch for MLA, and change the dtype of
prefill_mask to avoid an accuracy problem.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- vLLM version: v0.10.2
- vLLM main:
ef7eefe17a
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
This commit is contained in:
@@ -148,10 +148,6 @@ msgid ""
|
||||
" to be passed in."
|
||||
msgstr "在为MOE模型使用专家负载均衡时,需要传入专家映射路径。"
|
||||
|
||||
#: ../../user_guide/configuration/additional_config.md
|
||||
msgid "`chunked_prefill_for_mla`"
|
||||
msgstr "`chunked_prefill_for_mla`"
|
||||
|
||||
#: ../../user_guide/configuration/additional_config.md
|
||||
msgid "`False`"
|
||||
msgstr "`False`"
|
||||
|
||||
@@ -30,7 +30,6 @@ The following table lists the additional configuration options available in vLLM
|
||||
| `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler |
|
||||
| `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. |
|
||||
| `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
|
||||
| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. |
|
||||
| `enable_prefetch` | bool | `False` | Whether to enable weight prefetch. |
|
||||
| `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
|
||||
| `enable_shared_expert_dp` | bool | `False` | When the shared expert in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported to use. |
|
||||
|
||||
@@ -70,9 +70,7 @@ vllm serve /models/deepseek_r1_w8a8 \
|
||||
"kv_port": "20001",
|
||||
"engine_id": "0",
|
||||
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
|
||||
}' \
|
||||
--additional-config \
|
||||
'{"chunked_prefill_for_mla":true}'
|
||||
}'
|
||||
```
|
||||
|
||||
Run prefill server P2 on second node:
|
||||
@@ -114,9 +112,7 @@ vllm serve /models/deepseek_r1_w8a8 \
|
||||
"kv_port": "20001",
|
||||
"engine_id": "0",
|
||||
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
|
||||
}' \
|
||||
--additional-config \
|
||||
'{"chunked_prefill_for_mla":true}'
|
||||
}'
|
||||
```
|
||||
|
||||
Run decode server d1 on third node:
|
||||
|
||||
@@ -20,7 +20,6 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
|
||||
from vllm_ascend.utils import npu_prefetch
|
||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||
|
||||
@@ -491,7 +490,6 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
self.enable_prefetch = ascend_config.enable_prefetch
|
||||
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
||||
self.chunked_prefill_for_mla = ascend_config.chunked_prefill_for_mla
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.ring_mla_mask_size = 512
|
||||
@@ -673,84 +671,47 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.v_head_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
query = torch.cat((q_nope, q_pe), dim=-1)
|
||||
key = torch.cat((k_nope, k_pe), dim=-1)
|
||||
torch_npu._npu_flash_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=attn_metadata.attn_mask,
|
||||
seq_len=attn_metadata.prefill.context_lens,
|
||||
scale_value=self.scale,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_heads,
|
||||
out=attn_output)
|
||||
elif self.chunked_prefill_for_mla:
|
||||
attn_lse = torch.empty(self.num_heads,
|
||||
num_tokens,
|
||||
dtype=torch.float32,
|
||||
device=q_nope.device)
|
||||
if self.prefill_mask is None:
|
||||
self.prefill_mask = torch.triu(
|
||||
torch.ones(self.ring_mla_mask_size,
|
||||
self.ring_mla_mask_size,
|
||||
device=q_nope.device,
|
||||
dtype=q_nope.dtype), 1)
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=value,
|
||||
mask=self.prefill_mask,
|
||||
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
|
||||
dtype=torch.int32),
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=None,
|
||||
prev_lse=None,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="mask_type_triu",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_first_ring",
|
||||
output=attn_output,
|
||||
softmax_lse=attn_lse)
|
||||
attn_output, attn_lse = self._compute_prefill_context( \
|
||||
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
||||
else:
|
||||
query = torch.cat((q_nope, q_pe), dim=-1)
|
||||
attn_output_torch = torch.empty(num_tokens,
|
||||
self.num_heads * self.v_head_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
# current requests is chunked in prefill, disable flash attention with chunked prefill
|
||||
vanilla_chunked_prefill_mla(
|
||||
output=attn_output_torch,
|
||||
query=query,
|
||||
kv_cache=kv_c_and_k_pe_cache,
|
||||
block_tables=attn_metadata.prefill.block_table,
|
||||
query_lens=attn_metadata.prefill.query_lens,
|
||||
context_lens=attn_metadata.prefill.context_lens,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
max_query_len=attn_metadata.prefill.max_query_len,
|
||||
max_context_len=attn_metadata.prefill.max_seq_lens,
|
||||
nope_dim=self.qk_nope_head_dim,
|
||||
rope_dim=self.qk_rope_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
scale=self.scale,
|
||||
alibi_slopes=None,
|
||||
causal=True)
|
||||
attn_lse = torch.empty(self.num_heads,
|
||||
num_tokens,
|
||||
dtype=torch.float32,
|
||||
device=q_nope.device)
|
||||
if self.prefill_mask is None:
|
||||
if q_nope.dtype == torch.float16:
|
||||
mask_value = torch.finfo(torch.float32).min
|
||||
else:
|
||||
mask_value = 1
|
||||
prefill_mask = torch.triu(
|
||||
torch.ones(self.ring_mla_mask_size,
|
||||
self.ring_mla_mask_size,
|
||||
device=q_nope.device,
|
||||
dtype=q_nope.dtype), 1)
|
||||
self.prefill_mask = torch.where(prefill_mask == 1, mask_value,
|
||||
0).to(q_nope.dtype)
|
||||
torch_npu.atb.npu_ring_mla(q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=value,
|
||||
mask=self.prefill_mask,
|
||||
seqlen=torch.tensor(
|
||||
attn_metadata.prefill.query_lens,
|
||||
dtype=torch.int32),
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=None,
|
||||
prev_lse=None,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="mask_type_triu",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_first_ring",
|
||||
output=attn_output,
|
||||
softmax_lse=attn_lse)
|
||||
attn_output, attn_lse = self._compute_prefill_context( \
|
||||
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
||||
|
||||
attn_output = attn_output.reshape(
|
||||
[num_tokens, self.num_heads * self.v_head_dim])
|
||||
if attn_metadata.attn_state in [
|
||||
AscendAttentionState.ChunkedPrefill,
|
||||
AscendAttentionState.SpecDecoding,
|
||||
AscendAttentionState.PrefillCacheHit
|
||||
] and not self.chunked_prefill_for_mla:
|
||||
attn_output = attn_output_torch
|
||||
return attn_output
|
||||
|
||||
def exec_kv_decode(
|
||||
|
||||
@@ -23,7 +23,6 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
|
||||
from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
|
||||
npu_stream_switch, npu_wait_tensor)
|
||||
from vllm_ascend.utils import npu_prefetch
|
||||
@@ -674,6 +673,8 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
||||
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
self.running_in_graph = False
|
||||
self.prefill_mask = None
|
||||
self.ring_mla_mask_size = 512
|
||||
|
||||
# Adapt torch air graph mode with spec decoding.
|
||||
speculative_config = get_current_vllm_config().speculative_config
|
||||
@@ -820,16 +821,13 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
||||
k_nope, v = kv_nope\
|
||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
||||
mask = torch.triu(
|
||||
torch.ones(512, 512, device=query.device, dtype=query.dtype),
|
||||
1)
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=v,
|
||||
mask=mask,
|
||||
mask=self.prefill_mask,
|
||||
seqlen=seq_len,
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
@@ -861,104 +859,54 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
||||
self.v_head_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
attn_lse = torch.empty(self.num_heads,
|
||||
num_tokens,
|
||||
dtype=torch.float32,
|
||||
device=query.device)
|
||||
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
|
||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
||||
# Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
|
||||
ascend_config = get_ascend_config()
|
||||
q_pe = query[..., self.qk_nope_head_dim:]
|
||||
q_nope = query[..., :self.qk_nope_head_dim]
|
||||
if self.prefill_mask is None:
|
||||
if q_nope.dtype == torch.float16:
|
||||
mask_value = torch.finfo(torch.float32).min
|
||||
else:
|
||||
mask_value = 1
|
||||
prefill_mask = torch.triu(
|
||||
torch.ones(self.ring_mla_mask_size,
|
||||
self.ring_mla_mask_size,
|
||||
device=q_nope.device,
|
||||
dtype=q_nope.dtype), 1)
|
||||
self.prefill_mask = torch.where(prefill_mask == 1, mask_value,
|
||||
0).to(q_nope.dtype)
|
||||
torch_npu.atb.npu_ring_mla(q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=value,
|
||||
mask=self.prefill_mask,
|
||||
seqlen=torch.tensor(
|
||||
attn_metadata.prefill.query_lens,
|
||||
dtype=torch.int32),
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=None,
|
||||
prev_lse=None,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="mask_type_triu",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_first_ring",
|
||||
output=attn_output,
|
||||
softmax_lse=attn_lse)
|
||||
attn_output, attn_lse = self._compute_prefill_context( \
|
||||
query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
||||
|
||||
if attn_metadata.attn_state in [
|
||||
AscendAttentionState.ChunkedPrefill,
|
||||
AscendAttentionState.SpecDecoding,
|
||||
AscendAttentionState.PrefillCacheHit
|
||||
] and not ascend_config.chunked_prefill_for_mla:
|
||||
attn_output_torch = torch.empty(num_tokens,
|
||||
self.num_heads * self.v_head_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
# current requests is chunked in prefill, disable flash attention with chunked prefill
|
||||
vanilla_chunked_prefill_mla(
|
||||
output=attn_output_torch,
|
||||
query=query,
|
||||
kv_cache=kv_c_and_k_pe_cache,
|
||||
block_tables=attn_metadata.prefill.block_table,
|
||||
query_lens=attn_metadata.prefill.query_lens,
|
||||
context_lens=attn_metadata.prefill.context_lens,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
max_query_len=attn_metadata.prefill.max_query_len,
|
||||
max_context_len=attn_metadata.prefill.max_seq_lens,
|
||||
nope_dim=self.qk_nope_head_dim,
|
||||
rope_dim=self.qk_rope_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
scale=self.scale,
|
||||
alibi_slopes=None,
|
||||
causal=True)
|
||||
elif attn_metadata.attn_state in [
|
||||
AscendAttentionState.ChunkedPrefill,
|
||||
AscendAttentionState.SpecDecoding,
|
||||
AscendAttentionState.PrefillCacheHit
|
||||
]:
|
||||
attn_lse = torch.empty(self.num_heads,
|
||||
num_tokens,
|
||||
dtype=torch.float32,
|
||||
device=query.device)
|
||||
q_pe = query[..., self.qk_nope_head_dim:]
|
||||
q_nope = query[..., :self.qk_nope_head_dim]
|
||||
mask = torch.triu(
|
||||
torch.ones(512, 512, device=query.device, dtype=query.dtype),
|
||||
1) # 512: mask only support 512
|
||||
if attn_metadata.num_prefills > 1:
|
||||
mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
|
||||
1)
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=value,
|
||||
mask=mask,
|
||||
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
|
||||
dtype=torch.int32),
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=None,
|
||||
prev_lse=None,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="mask_type_triu",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_first_ring",
|
||||
output=attn_output,
|
||||
softmax_lse=attn_lse)
|
||||
attn_output, attn_lse = self._compute_prefill_context( \
|
||||
query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
||||
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
key = torch.cat((k_nope, k_pe), dim=-1)
|
||||
torch_npu._npu_flash_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=attn_metadata.attn_mask,
|
||||
seq_len=attn_metadata.prefill.context_lens,
|
||||
scale_value=self.scale,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_heads,
|
||||
out=attn_output)
|
||||
attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Unexpected path reached, AscendMLATorchairImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
|
||||
)
|
||||
attn_output = attn_output.reshape(
|
||||
[num_tokens, self.num_heads * self.v_head_dim])
|
||||
if attn_metadata.attn_state in [
|
||||
AscendAttentionState.ChunkedPrefill,
|
||||
AscendAttentionState.SpecDecoding,
|
||||
AscendAttentionState.PrefillCacheHit
|
||||
] and not ascend_config.chunked_prefill_for_mla:
|
||||
attn_output = attn_output_torch
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
Reference in New Issue
Block a user