[Bugfix] Fix the input constraints checks for the mlapo and bmm_transpose operators (#5764)
### What this PR does / why we need it?
This PR fixes the input constraint checks for the mlapo and bmm_transpose
operators.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with new added/existing test.
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
### Perf
64K/3K, 1P1D, bs=32
before this pr:
TPOT 29ms, TTFT 47s, TPS 606 token/s
after this pr:
TPOT 29ms, TTFT 48s, TPS 636 token/s
Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -57,6 +57,8 @@ else:
|
||||
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
|
||||
BUILD_METADATA_STEP_PREFILL = 0
|
||||
BUILD_METADATA_STEP_DECODE = 1
|
||||
# token count limits within the mlapo operator
|
||||
MLAPO_MAX_SUPPORTED_TOKENS = 1024
|
||||
|
||||
|
||||
class AscendMLABackend(AttentionBackend):
|
||||
@@ -927,10 +929,9 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# On KV consumers (decode-only) MLAPO uses the transformed weights built above;
|
||||
# the original fused_qkv_a_proj/q_proj weights and quant params are no longer
|
||||
# referenced, so drop them to save memory.
|
||||
ascend_config = get_ascend_config()
|
||||
if self.vllm_config.kv_transfer_config is not None and \
|
||||
self.vllm_config.kv_transfer_config.is_kv_consumer and \
|
||||
ascend_config.recompute_scheduler_enable:
|
||||
self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
|
||||
self.fused_qkv_a_proj.weight = None
|
||||
self.fused_qkv_a_proj.deq_scale = None
|
||||
self.fused_qkv_a_proj.quant_bias = None
|
||||
@@ -1508,7 +1509,9 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
device=hidden_states.device)
|
||||
|
||||
# MLA Preprocess
|
||||
if self.enable_mlapo and not has_prefill:
|
||||
if self.enable_mlapo and \
|
||||
not has_prefill and \
|
||||
attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states.contiguous(), need_gather_q_kv)
|
||||
decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess_only_decode(
|
||||
|
||||
Reference in New Issue
Block a user