[Bugfix] Fix the input constraints checks for the mlapo and bmm_transpose operators (#5764)

### What this PR does / why we need it?
This PR fixes the input constraint checks for the mlapo and bmm_transpose
operators. For mlapo, both the init-time weight-release check and the
forward-time dispatch check are now gated on the operator's maximum
supported token count (see the diff below).

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef

### Perf
64K/3K, 1P1D, bs=32:

|                | TPOT  | TTFT | TPS         |
|----------------|-------|------|-------------|
| Before this PR | 29 ms | 47 s | 606 token/s |
| After this PR  | 29 ms | 48 s | 636 token/s |

Signed-off-by: rjg-lyh <1318825571@qq.com>
Author: rjg-lyh
Date: 2026-01-16 17:52:48 +08:00 (committed by GitHub)
Parent: 4f446aec4c
Commit: 3af91e5ac4
3 changed files with 28 additions and 37 deletions


```diff
@@ -57,6 +57,8 @@ else:
 MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
 BUILD_METADATA_STEP_PREFILL = 0
 BUILD_METADATA_STEP_DECODE = 1
+# token count limits within the mlapo operator
+MLAPO_MAX_SUPPORTED_TOKENS = 1024
 
 
 class AscendMLABackend(AttentionBackend):
```
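
For context, the fused MLA-preprocess (mlapo) kernel only handles a bounded number of tokens per invocation; the new constant names that bound so both call sites below can share it. A minimal sketch of the kind of input-constraint check the constant enables (illustrative only, not the repository's code):

```python
# Sketch only: reject inputs the fused mlapo kernel cannot handle.
# Batches above the limit must take the generic fallback path instead.
MLAPO_MAX_SUPPORTED_TOKENS = 1024  # limit from the hunk above


def check_mlapo_inputs(num_tokens: int) -> None:
    """Raise if the token count exceeds the fused kernel's limit."""
    if num_tokens > MLAPO_MAX_SUPPORTED_TOKENS:
        raise ValueError(
            f"mlapo supports at most {MLAPO_MAX_SUPPORTED_TOKENS} tokens "
            f"per call, got {num_tokens}")
```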
```diff
@@ -927,10 +929,9 @@ class AscendMLAImpl(MLAAttentionImpl):
         # On KV consumers (decode-only) MLAPO uses the transformed weights built above;
         # the original fused_qkv_a_proj/q_proj weights and quant params are no longer
         # referenced, so drop them to save memory.
-        ascend_config = get_ascend_config()
         if self.vllm_config.kv_transfer_config is not None and \
             self.vllm_config.kv_transfer_config.is_kv_consumer and \
-            ascend_config.recompute_scheduler_enable:
+            self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
             self.fused_qkv_a_proj.weight = None
             self.fused_qkv_a_proj.deq_scale = None
             self.fused_qkv_a_proj.quant_bias = None
```
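
The reasoning behind this hunk: dropping the original `fused_qkv_a_proj` weights is only safe if the non-mlapo fallback path, which still reads them, can never run. The old gate (`recompute_scheduler_enable`) did not guarantee that; the scheduler's `max_num_batched_tokens` is a worst-case bound on batch size, so comparing it against `MLAPO_MAX_SUPPORTED_TOKENS` does. A hedged sketch of the invariant, with illustrative names:

```python
MLAPO_MAX_SUPPORTED_TOKENS = 1024


def may_release_fallback_weights(is_kv_consumer: bool,
                                 max_num_batched_tokens: int) -> bool:
    """Sketch of the init-time invariant, not the actual implementation.

    Weights read only by the fallback path may be freed iff no batch can
    ever exceed the fused kernel's limit, making the fallback unreachable.
    """
    return (is_kv_consumer
            and max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS)
```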
```diff
@@ -1508,7 +1509,9 @@ class AscendMLAImpl(MLAAttentionImpl):
                                     device=hidden_states.device)
         # MLA Preprocess
-        if self.enable_mlapo and not has_prefill:
+        if self.enable_mlapo and \
+            not has_prefill and \
+            attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 hidden_states.contiguous(), need_gather_q_kv)
             decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess_only_decode(
```
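
The forward-time half of the fix: even with mlapo enabled, a decode batch larger than the limit now falls through to the generic preprocess path, which is exactly why the weights above may only be dropped when that path is provably unreachable. A minimal sketch of the dispatch decision; the returned labels are hypothetical stand-ins for the real `_mla_preprocess*` methods:

```python
MLAPO_MAX_SUPPORTED_TOKENS = 1024


def select_preprocess_path(enable_mlapo: bool, has_prefill: bool,
                           num_decode_tokens: int) -> str:
    """Sketch of the dispatch decision; returns which path would run."""
    if (enable_mlapo and not has_prefill
            and num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS):
        return "mlapo_decode_only"  # fused fast path
    return "default"  # generic path: prefill present or oversize decode batch
```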