From 3af91e5ac4124ea5e2f5a286ebc0c46ed5f0a368 Mon Sep 17 00:00:00 2001
From: rjg-lyh <83491835+rjg-lyh@users.noreply.github.com>
Date: Fri, 16 Jan 2026 17:52:48 +0800
Subject: [PATCH] [Bugfix] Fix the input constraint checks for the mlapo and
 bmm_transpose operators (#5764)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?
This PR fixes the input constraint checks for the mlapo and bmm_transpose
operators: dispatch now gates on the actual token count these kernels
support rather than on the `has_prefill` proxy flag (a minimal sketch of
the new gating rule is appended after the diff).

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/2f4e6548efec402b913ffddc8726230d9311948d

### Perf
64K/3K, 1P1D, bs=32
- before this PR: TPOT 29 ms, TTFT 47 s, TPS 606 tokens/s
- after this PR: TPOT 29 ms, TTFT 48 s, TPS 636 tokens/s

Signed-off-by: rjg-lyh <1318825571@qq.com>
---
 tests/ut/attention/test_sfa_v1.py |  3 --
 vllm_ascend/attention/mla_v1.py   |  9 ++++--
 vllm_ascend/attention/sfa_v1.py   | 53 +++++++++++++------------------
 3 files changed, 28 insertions(+), 37 deletions(-)

diff --git a/tests/ut/attention/test_sfa_v1.py b/tests/ut/attention/test_sfa_v1.py
index 2fdddf12..4bcfd3c6 100644
--- a/tests/ut/attention/test_sfa_v1.py
+++ b/tests/ut/attention/test_sfa_v1.py
@@ -36,7 +36,6 @@ class TestAscendSFABackend(TestBase):
 class TestAscendSFAMetadata(TestBase):
 
     def test_ascend_sfa_metadata_default(self):
-        has_prefill = True
         num_actual_tokens = 100
         slot_mapping = torch.randn(100, 4, 1024)
         seq_lens = torch.tensor([30, 50])
@@ -54,7 +53,6 @@ class TestAscendSFAMetadata(TestBase):
         attn_state = AscendAttentionState.ChunkedPrefill
 
         metadata = AscendSFAMetadata(
-            has_prefill=has_prefill,
             num_actual_tokens=num_actual_tokens,
             slot_mapping=slot_mapping,
             seq_lens=seq_lens,
@@ -68,7 +66,6 @@ class TestAscendSFAMetadata(TestBase):
             attn_state=attn_state,
         )
 
-        self.assertEqual(metadata.has_prefill, has_prefill)
         self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
         self.assertIs(metadata.slot_mapping, slot_mapping)
         self.assertTrue(torch.equal(metadata.seq_lens, seq_lens))
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 74dbe7da..128d7547 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -57,6 +57,8 @@ else:
 MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
 BUILD_METADATA_STEP_PREFILL = 0
 BUILD_METADATA_STEP_DECODE = 1
+# token count limits within the mlapo operator
+MLAPO_MAX_SUPPORTED_TOKENS = 1024
 
 
 class AscendMLABackend(AttentionBackend):
@@ -927,10 +929,9 @@ class AscendMLAImpl(MLAAttentionImpl):
         # On KV consumers (decode-only) MLAPO uses the transformed weights built above;
         # the original fused_qkv_a_proj/q_proj weights and quant params are no longer
         # referenced, so drop them to save memory.
-        ascend_config = get_ascend_config()
         if self.vllm_config.kv_transfer_config is not None and \
             self.vllm_config.kv_transfer_config.is_kv_consumer and \
-                ascend_config.recompute_scheduler_enable:
+                self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
             self.fused_qkv_a_proj.weight = None
             self.fused_qkv_a_proj.deq_scale = None
             self.fused_qkv_a_proj.quant_bias = None
@@ -1508,7 +1509,9 @@ class AscendMLAImpl(MLAAttentionImpl):
                 device=hidden_states.device)
 
         # MLA Preprocess
-        if self.enable_mlapo and not has_prefill:
+        if self.enable_mlapo and \
+            not has_prefill and \
+            attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 hidden_states.contiguous(), need_gather_q_kv)
             decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess_only_decode(
diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index 0fd62499..3cf1d9e3 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -18,8 +18,9 @@ from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
+from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE, MLAPO_MAX_SUPPORTED_TOKENS
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
+                                         maybe_save_kv_layer_to_connector,
                                          trans_rope_weight, transdata,
                                          wait_for_kv_layer_from_connector)
 from vllm_ascend.distributed.utils import all_gather_async
@@ -47,6 +48,9 @@ else:
     AttentionBackend, AttentionCGSupport, MLAAttentionImpl)
 # isort: on
 
+# token count limits within bmm_transpose operator
+BMM_TRANS_MAX_SUPPORTED_TOKENS = 1024
+
 
 class AscendSFABackend(AttentionBackend):
 
@@ -99,7 +103,6 @@ class AscendSFAMetadata:
     #  |---------- context_len ----------|
     #  |-------------------- seq_len ---------------------|
     #                                   |-- query_len ---|
-    has_prefill: bool
     num_actual_tokens: int  # Number of tokens excluding padding.
     slot_mapping: torch.Tensor
     seq_lens: torch.Tensor
@@ -196,16 +199,11 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
         input_positions = common_attn_metadata.positions[:
                                                          num_input_tokens].long(
                                                          )
-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
-        query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
-        has_prefill = any(query_lens_cpu > self.decode_threshold)
         cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1]
         seq_lens = common_attn_metadata.seq_lens[:num_reqs]
-        if has_prefill:
-            cos, sin = get_cos_and_sin_mla(input_positions)
-        else:
-            cos, sin = get_cos_and_sin_mla(input_positions, True)
+
+        cos, sin = get_cos_and_sin_mla(input_positions, True)
 
         sfa_cp_context = None
         if self.enable_sfa_cp:
@@ -285,7 +283,6 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
         )
 
         return self.metadata_cls(  # type: ignore
-            has_prefill=has_prefill,
            num_input_tokens=common_attn_metadata.num_input_tokens,
            num_actual_tokens=num_actual_tokens,
            cum_query_lens=cum_query_lens,
@@ -368,7 +365,6 @@ class AscendSFAImpl(MLAAttentionImpl):
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tp_group().rank_in_group
-        self.num_heads_per_rank = self.num_heads // self.tp_size
         self.q_b_proj = kwargs['q_b_proj']
 
         ascend_config = get_ascend_config()
@@ -469,21 +465,17 @@ class AscendSFAImpl(MLAAttentionImpl):
             # if mlapo, W_UK_T can't trans nz
             self.W_UK_T = maybe_trans_nz(self.W_UK_T)
 
-    def _v_up_proj(self, x, has_prefill: bool):
-        # TODO(zzzzwwjj): We should not judge by whether `has_prefill` or not.
-        # The true criteria for judgment is tensorA's shape[0] <= 1024 (num_tokens <= 1024).
-        # This is a bug in the previous code.
+    def _v_up_proj(self, x):
+        num_input_tokens, _, _ = x.shape
         if x.dtype in [torch.float16, torch.bfloat16] \
             and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \
-            and not self.enable_sfa_cp \
-            and not has_prefill:
-            x = x.view(-1, self.num_heads, self.kv_lora_rank)
-            b, _, _ = x.shape
-            res = torch.empty((b, self.num_heads, self.v_head_dim),
+            and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS:
+            x = x.view(-1, self.local_num_heads, self.kv_lora_rank)
+            res = torch.empty((num_input_tokens, self.local_num_heads, self.v_head_dim),
                               dtype=x.dtype,
                               device=x.device)
             torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
-            x = res.reshape(-1, self.num_heads * self.v_head_dim)
+            x = res.reshape(-1, self.local_num_heads * self.v_head_dim)
         else:
             # Convert from (B, N, L) to (N, B, L)
             x = x.view(-1, self.local_num_heads,
@@ -654,10 +646,9 @@ class AscendSFAImpl(MLAAttentionImpl):
         # On KV consumers (decode-only) MLAPO uses the transformed weights built above;
         # the original fused_qkv_a_proj/q_proj weights and quant params are no longer
         # referenced, so drop them to save memory.
-        ascend_config = get_ascend_config()
         if self.vllm_config.kv_transfer_config is not None and \
             self.vllm_config.kv_transfer_config.is_kv_consumer and \
-                ascend_config.recompute_scheduler_enable:
+                self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
             self.fused_qkv_a_proj.weight = None
             self.fused_qkv_a_proj.deq_scale = None
             self.fused_qkv_a_proj.quant_bias = None
@@ -745,7 +736,6 @@ class AscendSFAImpl(MLAAttentionImpl):
             reach_layer_for_shard_weight_series(layer)
             return output.fill_(0)
 
-        has_prefill = attn_metadata.has_prefill
         cos = attn_metadata.cos
         sin = attn_metadata.sin
         actual_seq_lengths_query = attn_metadata.cum_query_lens
@@ -753,17 +743,16 @@ class AscendSFAImpl(MLAAttentionImpl):
         if self.enable_sfa_cp:
             need_gather_q_kv = False
         # Inputs and outputs may be padded for CUDA graphs
+        num_input_tokens = attn_metadata.num_input_tokens
         output_padded = output
 
-        # TODO(zzzzwwjj): In sfa, prefill and decode have the same calculation formula,
-        # so `has_prefill` here is not necessary.
-        if self.enable_mlapo and not has_prefill:
+        if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
             hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
                 hidden_states=hidden_states,
                 kv_cache=kv_cache,
                 attn_metadata=attn_metadata,
                 need_gather_q_kv=need_gather_q_kv,
-                num_input_tokens=attn_metadata.num_input_tokens,
+                num_input_tokens=num_input_tokens,
             )
             q, k = self.indexer_select_pre_process(
                 x=hidden_states,
@@ -796,8 +785,7 @@ class AscendSFAImpl(MLAAttentionImpl):
                 sin=sin,
                 need_gather_q_kv=need_gather_q_kv)
 
-        if has_prefill:
-            wait_for_kv_layer_from_connector(layer_name)
+        wait_for_kv_layer_from_connector(layer_name)
 
         slot_mapping = attn_metadata.slot_mapping
         if self.enable_sfa_cp:
@@ -875,12 +863,15 @@ class AscendSFAImpl(MLAAttentionImpl):
             sparse_mode=3,
         )
 
-        attn_output = self._v_up_proj(attn_output, has_prefill)
+        attn_output = self._v_up_proj(attn_output)
         maybe_npu_prefetch(inputs=self.o_proj.weight,
                            dependency=attn_output,
                            max_size=MAX_O_PROJ_PREFETCH_SIZE,
                            enabled=self.enable_prefetch)
         output[...] = self.o_proj(attn_output)[0]
+
+        maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
+
         return output_padded
 
     def indexer_select_pre_process(
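
As referenced in the PR description above, here is a minimal sketch of the gating rule this patch enforces: the mlapo and batch_matmul_transpose kernels only accept inputs whose token dimension is at most 1024, so dispatch must compare the actual runtime token count against that limit rather than use `has_prefill` as a proxy. The constants mirror the patch; the helper names (`use_fused_mlapo`, `use_bmm_transpose`) and their signatures are illustrative only, not part of vllm-ascend's API.

```python
# Minimal sketch of the token-count gating introduced by this PR.
# Constants mirror the patch; the helper functions are illustrative only.
MLAPO_MAX_SUPPORTED_TOKENS = 1024       # mlapo kernel input constraint
BMM_TRANS_MAX_SUPPORTED_TOKENS = 1024   # batch_matmul_transpose input constraint


def use_fused_mlapo(enable_mlapo: bool, num_input_tokens: int) -> bool:
    # Dispatch on the kernel's true constraint (token count), not on a
    # prefill/decode flag that merely correlates with it.
    return enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS


def use_bmm_transpose(dtype_ok: bool, num_input_tokens: int) -> bool:
    # batch_matmul_transpose requires shape[0] <= 1024 on its first input.
    return dtype_ok and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS


# Boundary behavior: 1024 tokens still takes the fused path, 1025 falls back.
assert use_fused_mlapo(True, 1024)
assert not use_fused_mlapo(True, 1025)
```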