diff --git a/docs/source/locale/zh_CN/LC_MESSAGES/user_guide/configuration/additional_config.po b/docs/source/locale/zh_CN/LC_MESSAGES/user_guide/configuration/additional_config.po index 54dacd6..f59ccbf 100644 --- a/docs/source/locale/zh_CN/LC_MESSAGES/user_guide/configuration/additional_config.po +++ b/docs/source/locale/zh_CN/LC_MESSAGES/user_guide/configuration/additional_config.po @@ -148,10 +148,6 @@ msgid "" " to be passed in." msgstr "在为MOE模型使用专家负载均衡时,需要传入专家映射路径。" -#: ../../user_guide/configuration/additional_config.md -msgid "`chunked_prefill_for_mla`" -msgstr "`chunked_prefill_for_mla`" - #: ../../user_guide/configuration/additional_config.md msgid "`False`" msgstr "`False`" diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md index eddb1c4..56fcbd8 100644 --- a/docs/source/user_guide/configuration/additional_config.md +++ b/docs/source/user_guide/configuration/additional_config.md @@ -30,7 +30,6 @@ The following table lists the additional configuration options available in vLLM | `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler | | `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. | | `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. | -| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. | | `enable_prefetch` | bool | `False` | Whether to enable weight prefetch. | | `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. | | `enable_shared_expert_dp` | bool | `False` | When the shared expert in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported to use. | diff --git a/examples/disaggregated_prefill_v1/README.md b/examples/disaggregated_prefill_v1/README.md index c42cace..fabcf6b 100644 --- a/examples/disaggregated_prefill_v1/README.md +++ b/examples/disaggregated_prefill_v1/README.md @@ -70,9 +70,7 @@ vllm serve /models/deepseek_r1_w8a8 \ "kv_port": "20001", "engine_id": "0", "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" - }' \ - --additional-config \ - '{"chunked_prefill_for_mla":true}' + }' ``` Run prefill server P2 on second node: @@ -114,9 +112,7 @@ vllm serve /models/deepseek_r1_w8a8 \ "kv_port": "20001", "engine_id": "0", "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" - }' \ - --additional-config \ - '{"chunked_prefill_for_mla":true}' + }' ``` Run decode server d1 on third node: diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index d287bad..6ee943d 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -20,7 +20,6 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig from vllm_ascend.multistream.context import get_multistream_comm_context from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn -from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla from vllm_ascend.utils import npu_prefetch from vllm_ascend.worker.npu_input_batch import InputBatch @@ -491,7 +490,6 @@ class AscendMLAImpl(MLAAttentionImpl): self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp self.enable_prefetch = ascend_config.enable_prefetch self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz - self.chunked_prefill_for_mla = ascend_config.chunked_prefill_for_mla vllm_config = get_current_vllm_config() self.ring_mla_mask_size = 512 @@ -673,84 +671,47 @@ class AscendMLAImpl(MLAAttentionImpl): self.v_head_dim, dtype=q_nope.dtype, device=q_nope.device) - if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: - query = torch.cat((q_nope, q_pe), dim=-1) - key = torch.cat((k_nope, k_pe), dim=-1) - torch_npu._npu_flash_attention( - query=query, - key=key, - value=value, - mask=attn_metadata.attn_mask, - seq_len=attn_metadata.prefill.context_lens, - scale_value=self.scale, - num_heads=self.num_heads, - num_kv_heads=self.num_heads, - out=attn_output) - elif self.chunked_prefill_for_mla: - attn_lse = torch.empty(self.num_heads, - num_tokens, - dtype=torch.float32, - device=q_nope.device) - if self.prefill_mask is None: - self.prefill_mask = torch.triu( - torch.ones(self.ring_mla_mask_size, - self.ring_mla_mask_size, - device=q_nope.device, - dtype=q_nope.dtype), 1) - torch_npu.atb.npu_ring_mla( - q_nope=q_nope, - q_rope=q_pe, - k_nope=k_nope, - k_rope=k_pe, - value=value, - mask=self.prefill_mask, - seqlen=torch.tensor(attn_metadata.prefill.query_lens, - dtype=torch.int32), - head_num=self.num_heads, - kv_head_num=self.num_heads, - pre_out=None, - prev_lse=None, - qk_scale=self.scale, - kernel_type="kernel_type_high_precision", - mask_type="mask_type_triu", - input_layout="type_bsnd", - calc_type="calc_type_first_ring", - output=attn_output, - softmax_lse=attn_lse) - attn_output, attn_lse = self._compute_prefill_context( \ - q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse) - else: - query = torch.cat((q_nope, q_pe), dim=-1) - attn_output_torch = torch.empty(num_tokens, - self.num_heads * self.v_head_dim, - dtype=query.dtype, - device=query.device) - # current requests is chunked in prefill, disable flash attention with chunked prefill - vanilla_chunked_prefill_mla( - output=attn_output_torch, - query=query, - kv_cache=kv_c_and_k_pe_cache, - block_tables=attn_metadata.prefill.block_table, - query_lens=attn_metadata.prefill.query_lens, - context_lens=attn_metadata.prefill.context_lens, - kv_b_proj=self.kv_b_proj, - max_query_len=attn_metadata.prefill.max_query_len, - max_context_len=attn_metadata.prefill.max_seq_lens, - nope_dim=self.qk_nope_head_dim, - rope_dim=self.qk_rope_head_dim, - v_head_dim=self.v_head_dim, - scale=self.scale, - alibi_slopes=None, - causal=True) + attn_lse = torch.empty(self.num_heads, + num_tokens, + dtype=torch.float32, + device=q_nope.device) + if self.prefill_mask is None: + if q_nope.dtype == torch.float16: + mask_value = torch.finfo(torch.float32).min + else: + mask_value = 1 + prefill_mask = torch.triu( + torch.ones(self.ring_mla_mask_size, + self.ring_mla_mask_size, + device=q_nope.device, + dtype=q_nope.dtype), 1) + self.prefill_mask = torch.where(prefill_mask == 1, mask_value, + 0).to(q_nope.dtype) + torch_npu.atb.npu_ring_mla(q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope, + k_rope=k_pe, + value=value, + mask=self.prefill_mask, + seqlen=torch.tensor( + attn_metadata.prefill.query_lens, + dtype=torch.int32), + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=None, + prev_lse=None, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="mask_type_triu", + input_layout="type_bsnd", + calc_type="calc_type_first_ring", + output=attn_output, + softmax_lse=attn_lse) + attn_output, attn_lse = self._compute_prefill_context( \ + q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse) attn_output = attn_output.reshape( [num_tokens, self.num_heads * self.v_head_dim]) - if attn_metadata.attn_state in [ - AscendAttentionState.ChunkedPrefill, - AscendAttentionState.SpecDecoding, - AscendAttentionState.PrefillCacheHit - ] and not self.chunked_prefill_for_mla: - attn_output = attn_output_torch return attn_output def exec_kv_decode( diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py index f641f0d..c4b9ac2 100644 --- a/vllm_ascend/torchair/torchair_mla.py +++ b/vllm_ascend/torchair/torchair_mla.py @@ -23,7 +23,6 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig from vllm_ascend.multistream.context import get_multistream_comm_context from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn -from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata, npu_stream_switch, npu_wait_tensor) from vllm_ascend.utils import npu_prefetch @@ -674,6 +673,8 @@ class AscendMLATorchairImpl(MLAAttentionImpl): self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp self.running_in_graph = False + self.prefill_mask = None + self.ring_mla_mask_size = 512 # Adapt torch air graph mode with spec decoding. speculative_config = get_current_vllm_config().speculative_config @@ -820,16 +821,13 @@ class AscendMLATorchairImpl(MLAAttentionImpl): k_nope, v = kv_nope\ .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) - mask = torch.triu( - torch.ones(512, 512, device=query.device, dtype=query.dtype), - 1) torch_npu.atb.npu_ring_mla( q_nope=q_nope, q_rope=q_pe, k_nope=k_nope, k_rope=k_pe, value=v, - mask=mask, + mask=self.prefill_mask, seqlen=seq_len, head_num=self.num_heads, kv_head_num=self.num_heads, @@ -861,104 +859,54 @@ class AscendMLATorchairImpl(MLAAttentionImpl): self.v_head_dim, dtype=query.dtype, device=query.device) + attn_lse = torch.empty(self.num_heads, + num_tokens, + dtype=torch.float32, + device=query.device) k_nope, value = self.kv_b_proj(kv_c_normed)[0].view( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split( [self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache - ascend_config = get_ascend_config() + q_pe = query[..., self.qk_nope_head_dim:] + q_nope = query[..., :self.qk_nope_head_dim] + if self.prefill_mask is None: + if q_nope.dtype == torch.float16: + mask_value = torch.finfo(torch.float32).min + else: + mask_value = 1 + prefill_mask = torch.triu( + torch.ones(self.ring_mla_mask_size, + self.ring_mla_mask_size, + device=q_nope.device, + dtype=q_nope.dtype), 1) + self.prefill_mask = torch.where(prefill_mask == 1, mask_value, + 0).to(q_nope.dtype) + torch_npu.atb.npu_ring_mla(q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope, + k_rope=k_pe, + value=value, + mask=self.prefill_mask, + seqlen=torch.tensor( + attn_metadata.prefill.query_lens, + dtype=torch.int32), + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=None, + prev_lse=None, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="mask_type_triu", + input_layout="type_bsnd", + calc_type="calc_type_first_ring", + output=attn_output, + softmax_lse=attn_lse) + attn_output, attn_lse = self._compute_prefill_context( \ + query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse) - if attn_metadata.attn_state in [ - AscendAttentionState.ChunkedPrefill, - AscendAttentionState.SpecDecoding, - AscendAttentionState.PrefillCacheHit - ] and not ascend_config.chunked_prefill_for_mla: - attn_output_torch = torch.empty(num_tokens, - self.num_heads * self.v_head_dim, - dtype=query.dtype, - device=query.device) - # current requests is chunked in prefill, disable flash attention with chunked prefill - vanilla_chunked_prefill_mla( - output=attn_output_torch, - query=query, - kv_cache=kv_c_and_k_pe_cache, - block_tables=attn_metadata.prefill.block_table, - query_lens=attn_metadata.prefill.query_lens, - context_lens=attn_metadata.prefill.context_lens, - kv_b_proj=self.kv_b_proj, - max_query_len=attn_metadata.prefill.max_query_len, - max_context_len=attn_metadata.prefill.max_seq_lens, - nope_dim=self.qk_nope_head_dim, - rope_dim=self.qk_rope_head_dim, - v_head_dim=self.v_head_dim, - scale=self.scale, - alibi_slopes=None, - causal=True) - elif attn_metadata.attn_state in [ - AscendAttentionState.ChunkedPrefill, - AscendAttentionState.SpecDecoding, - AscendAttentionState.PrefillCacheHit - ]: - attn_lse = torch.empty(self.num_heads, - num_tokens, - dtype=torch.float32, - device=query.device) - q_pe = query[..., self.qk_nope_head_dim:] - q_nope = query[..., :self.qk_nope_head_dim] - mask = torch.triu( - torch.ones(512, 512, device=query.device, dtype=query.dtype), - 1) # 512: mask only support 512 - if attn_metadata.num_prefills > 1: - mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1, - 1) - torch_npu.atb.npu_ring_mla( - q_nope=q_nope, - q_rope=q_pe, - k_nope=k_nope, - k_rope=k_pe, - value=value, - mask=mask, - seqlen=torch.tensor(attn_metadata.prefill.query_lens, - dtype=torch.int32), - head_num=self.num_heads, - kv_head_num=self.num_heads, - pre_out=None, - prev_lse=None, - qk_scale=self.scale, - kernel_type="kernel_type_high_precision", - mask_type="mask_type_triu", - input_layout="type_bsnd", - calc_type="calc_type_first_ring", - output=attn_output, - softmax_lse=attn_lse) - attn_output, attn_lse = self._compute_prefill_context( \ - query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse) - - elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: - key = torch.cat((k_nope, k_pe), dim=-1) - torch_npu._npu_flash_attention( - query=query, - key=key, - value=value, - mask=attn_metadata.attn_mask, - seq_len=attn_metadata.prefill.context_lens, - scale_value=self.scale, - num_heads=self.num_heads, - num_kv_heads=self.num_heads, - out=attn_output) - attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim) - else: - raise RuntimeError( - "Unexpected path reached, AscendMLATorchairImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !" - ) attn_output = attn_output.reshape( [num_tokens, self.num_heads * self.v_head_dim]) - if attn_metadata.attn_state in [ - AscendAttentionState.ChunkedPrefill, - AscendAttentionState.SpecDecoding, - AscendAttentionState.PrefillCacheHit - ] and not ascend_config.chunked_prefill_for_mla: - attn_output = attn_output_torch return attn_output