[Bugfix] Fix the input constraints checks for the mlapo and bmm_transpose operators (#5764)

### What this PR does / why we need it?
This PR fix the input constraints checks for the mlapo and bmm_transpose
operators.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

### Perf
64K/3K,1P1D,bs=32

before this pr:
TPOT 29ms, TTFT 47s,TPS 606 token/s

after this pr:
TPOT 29ms, TTFT 48s,TPS 636 token/s

Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
rjg-lyh
2026-01-16 17:52:48 +08:00
committed by GitHub
parent 4f446aec4c
commit 3af91e5ac4
3 changed files with 28 additions and 37 deletions

View File

@@ -36,7 +36,6 @@ class TestAscendSFABackend(TestBase):
class TestAscendSFAMetadata(TestBase): class TestAscendSFAMetadata(TestBase):
def test_ascend_sfa_metadata_default(self): def test_ascend_sfa_metadata_default(self):
has_prefill = True
num_actual_tokens = 100 num_actual_tokens = 100
slot_mapping = torch.randn(100, 4, 1024) slot_mapping = torch.randn(100, 4, 1024)
seq_lens = torch.tensor([30, 50]) seq_lens = torch.tensor([30, 50])
@@ -54,7 +53,6 @@ class TestAscendSFAMetadata(TestBase):
attn_state = AscendAttentionState.ChunkedPrefill attn_state = AscendAttentionState.ChunkedPrefill
metadata = AscendSFAMetadata( metadata = AscendSFAMetadata(
has_prefill=has_prefill,
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
seq_lens=seq_lens, seq_lens=seq_lens,
@@ -68,7 +66,6 @@ class TestAscendSFAMetadata(TestBase):
attn_state=attn_state, attn_state=attn_state,
) )
self.assertEqual(metadata.has_prefill, has_prefill)
self.assertEqual(metadata.num_actual_tokens, num_actual_tokens) self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
self.assertIs(metadata.slot_mapping, slot_mapping) self.assertIs(metadata.slot_mapping, slot_mapping)
self.assertTrue(torch.equal(metadata.seq_lens, seq_lens)) self.assertTrue(torch.equal(metadata.seq_lens, seq_lens))

View File

@@ -57,6 +57,8 @@ else:
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
BUILD_METADATA_STEP_PREFILL = 0 BUILD_METADATA_STEP_PREFILL = 0
BUILD_METADATA_STEP_DECODE = 1 BUILD_METADATA_STEP_DECODE = 1
# token count limits within the mlapo operator
MLAPO_MAX_SUPPORTED_TOKENS = 1024
class AscendMLABackend(AttentionBackend): class AscendMLABackend(AttentionBackend):
@@ -927,10 +929,9 @@ class AscendMLAImpl(MLAAttentionImpl):
# On KV consumers (decode-only) MLAPO uses the transformed weights built above; # On KV consumers (decode-only) MLAPO uses the transformed weights built above;
# the original fused_qkv_a_proj/q_proj weights and quant params are no longer # the original fused_qkv_a_proj/q_proj weights and quant params are no longer
# referenced, so drop them to save memory. # referenced, so drop them to save memory.
ascend_config = get_ascend_config()
if self.vllm_config.kv_transfer_config is not None and \ if self.vllm_config.kv_transfer_config is not None and \
self.vllm_config.kv_transfer_config.is_kv_consumer and \ self.vllm_config.kv_transfer_config.is_kv_consumer and \
ascend_config.recompute_scheduler_enable: self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
self.fused_qkv_a_proj.weight = None self.fused_qkv_a_proj.weight = None
self.fused_qkv_a_proj.deq_scale = None self.fused_qkv_a_proj.deq_scale = None
self.fused_qkv_a_proj.quant_bias = None self.fused_qkv_a_proj.quant_bias = None
@@ -1508,7 +1509,9 @@ class AscendMLAImpl(MLAAttentionImpl):
device=hidden_states.device) device=hidden_states.device)
# MLA Preprocess # MLA Preprocess
if self.enable_mlapo and not has_prefill: if self.enable_mlapo and \
not has_prefill and \
attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states.contiguous(), need_gather_q_kv) hidden_states.contiguous(), need_gather_q_kv)
decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess_only_decode( decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess_only_decode(

View File

@@ -18,8 +18,9 @@ from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE, MLAPO_MAX_SUPPORTED_TOKENS
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
maybe_save_kv_layer_to_connector,
trans_rope_weight, transdata, trans_rope_weight, transdata,
wait_for_kv_layer_from_connector) wait_for_kv_layer_from_connector)
from vllm_ascend.distributed.utils import all_gather_async from vllm_ascend.distributed.utils import all_gather_async
@@ -47,6 +48,9 @@ else:
AttentionBackend, AttentionCGSupport, MLAAttentionImpl) AttentionBackend, AttentionCGSupport, MLAAttentionImpl)
# isort: on # isort: on
# token count limits within bmm_transpose operator
BMM_TRANS_MAX_SUPPORTED_TOKENS = 1024
class AscendSFABackend(AttentionBackend): class AscendSFABackend(AttentionBackend):
@@ -99,7 +103,6 @@ class AscendSFAMetadata:
# |---------- context_len ----------| # |---------- context_len ----------|
# |-------------------- seq_len ---------------------| # |-------------------- seq_len ---------------------|
# |-- query_len ---| # |-- query_len ---|
has_prefill: bool
num_actual_tokens: int # Number of tokens excluding padding. num_actual_tokens: int # Number of tokens excluding padding.
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
seq_lens: torch.Tensor seq_lens: torch.Tensor
@@ -196,15 +199,10 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
input_positions = common_attn_metadata.positions[: input_positions = common_attn_metadata.positions[:
num_input_tokens].long( num_input_tokens].long(
) )
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
has_prefill = any(query_lens_cpu > self.decode_threshold)
cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1] cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1]
seq_lens = common_attn_metadata.seq_lens[:num_reqs] seq_lens = common_attn_metadata.seq_lens[:num_reqs]
if has_prefill:
cos, sin = get_cos_and_sin_mla(input_positions)
else:
cos, sin = get_cos_and_sin_mla(input_positions, True) cos, sin = get_cos_and_sin_mla(input_positions, True)
sfa_cp_context = None sfa_cp_context = None
@@ -285,7 +283,6 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
) )
return self.metadata_cls( # type: ignore return self.metadata_cls( # type: ignore
has_prefill=has_prefill,
num_input_tokens=common_attn_metadata.num_input_tokens, num_input_tokens=common_attn_metadata.num_input_tokens,
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
cum_query_lens=cum_query_lens, cum_query_lens=cum_query_lens,
@@ -368,7 +365,6 @@ class AscendSFAImpl(MLAAttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tp_group().rank_in_group self.tp_rank = get_tp_group().rank_in_group
self.num_heads_per_rank = self.num_heads // self.tp_size
self.q_b_proj = kwargs['q_b_proj'] self.q_b_proj = kwargs['q_b_proj']
ascend_config = get_ascend_config() ascend_config = get_ascend_config()
@@ -469,21 +465,17 @@ class AscendSFAImpl(MLAAttentionImpl):
# if mlapo, W_UK_T can't trans nz # if mlapo, W_UK_T can't trans nz
self.W_UK_T = maybe_trans_nz(self.W_UK_T) self.W_UK_T = maybe_trans_nz(self.W_UK_T)
def _v_up_proj(self, x, has_prefill: bool): def _v_up_proj(self, x):
# TODO(zzzzwwjj): We should not judge by whether `has_prefill` or not. num_input_tokens, _, _ = x.shape
# The true criteria for judgment is tensorA's shape[0] <= 1024 (num_tokens <= 1024).
# This is a bug in the previous code.
if x.dtype in [torch.float16, torch.bfloat16] \ if x.dtype in [torch.float16, torch.bfloat16] \
and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \ and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \
and not self.enable_sfa_cp \ and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS:
and not has_prefill: x = x.view(-1, self.local_num_heads, self.kv_lora_rank)
x = x.view(-1, self.num_heads, self.kv_lora_rank) res = torch.empty((num_input_tokens, self.local_num_heads, self.v_head_dim),
b, _, _ = x.shape
res = torch.empty((b, self.num_heads, self.v_head_dim),
dtype=x.dtype, dtype=x.dtype,
device=x.device) device=x.device)
torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res) torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
x = res.reshape(-1, self.num_heads * self.v_head_dim) x = res.reshape(-1, self.local_num_heads * self.v_head_dim)
else: else:
# Convert from (B, N, L) to (N, B, L) # Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.local_num_heads, x = x.view(-1, self.local_num_heads,
@@ -654,10 +646,9 @@ class AscendSFAImpl(MLAAttentionImpl):
# On KV consumers (decode-only) MLAPO uses the transformed weights built above; # On KV consumers (decode-only) MLAPO uses the transformed weights built above;
# the original fused_qkv_a_proj/q_proj weights and quant params are no longer # the original fused_qkv_a_proj/q_proj weights and quant params are no longer
# referenced, so drop them to save memory. # referenced, so drop them to save memory.
ascend_config = get_ascend_config()
if self.vllm_config.kv_transfer_config is not None and \ if self.vllm_config.kv_transfer_config is not None and \
self.vllm_config.kv_transfer_config.is_kv_consumer and \ self.vllm_config.kv_transfer_config.is_kv_consumer and \
ascend_config.recompute_scheduler_enable: self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
self.fused_qkv_a_proj.weight = None self.fused_qkv_a_proj.weight = None
self.fused_qkv_a_proj.deq_scale = None self.fused_qkv_a_proj.deq_scale = None
self.fused_qkv_a_proj.quant_bias = None self.fused_qkv_a_proj.quant_bias = None
@@ -745,7 +736,6 @@ class AscendSFAImpl(MLAAttentionImpl):
reach_layer_for_shard_weight_series(layer) reach_layer_for_shard_weight_series(layer)
return output.fill_(0) return output.fill_(0)
has_prefill = attn_metadata.has_prefill
cos = attn_metadata.cos cos = attn_metadata.cos
sin = attn_metadata.sin sin = attn_metadata.sin
actual_seq_lengths_query = attn_metadata.cum_query_lens actual_seq_lengths_query = attn_metadata.cum_query_lens
@@ -753,17 +743,16 @@ class AscendSFAImpl(MLAAttentionImpl):
if self.enable_sfa_cp: if self.enable_sfa_cp:
need_gather_q_kv = False need_gather_q_kv = False
# Inputs and outputs may be padded for CUDA graphs # Inputs and outputs may be padded for CUDA graphs
num_input_tokens = attn_metadata.num_input_tokens
output_padded = output output_padded = output
# TODO(zzzzwwjj): In sfa, prefill and decode have the same calculation formula, if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
# so `has_prefill` here is not necessary.
if self.enable_mlapo and not has_prefill:
hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode( hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
need_gather_q_kv=need_gather_q_kv, need_gather_q_kv=need_gather_q_kv,
num_input_tokens=attn_metadata.num_input_tokens, num_input_tokens=num_input_tokens,
) )
q, k = self.indexer_select_pre_process( q, k = self.indexer_select_pre_process(
x=hidden_states, x=hidden_states,
@@ -796,7 +785,6 @@ class AscendSFAImpl(MLAAttentionImpl):
sin=sin, sin=sin,
need_gather_q_kv=need_gather_q_kv) need_gather_q_kv=need_gather_q_kv)
if has_prefill:
wait_for_kv_layer_from_connector(layer_name) wait_for_kv_layer_from_connector(layer_name)
slot_mapping = attn_metadata.slot_mapping slot_mapping = attn_metadata.slot_mapping
@@ -875,12 +863,15 @@ class AscendSFAImpl(MLAAttentionImpl):
sparse_mode=3, sparse_mode=3,
) )
attn_output = self._v_up_proj(attn_output, has_prefill) attn_output = self._v_up_proj(attn_output)
maybe_npu_prefetch(inputs=self.o_proj.weight, maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=attn_output, dependency=attn_output,
max_size=MAX_O_PROJ_PREFETCH_SIZE, max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch) enabled=self.enable_prefetch)
output[...] = self.o_proj(attn_output)[0] output[...] = self.o_proj(attn_output)[0]
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
return output_padded return output_padded
def indexer_select_pre_process( def indexer_select_pre_process(