[Bugfix] Fix the input constraint checks for the mlapo and bmm_transpose operators (#5764)
### What this PR does / why we need it?
This PR fixes the input constraint checks for the mlapo and bmm_transpose
operators.
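Both operators only support batches up to a fixed token count (1024 in this patch), so the fused paths are now gated on that limit and fall back to the regular path otherwise. Below is a minimal sketch of the gating pattern; the constants mirror the ones added in this patch, while the helper functions and their arguments are illustrative only, not the actual vllm-ascend code:

```python
import torch

# Limits taken from this patch; everything else below is a simplified sketch.
MLAPO_MAX_SUPPORTED_TOKENS = 1024
BMM_TRANS_MAX_SUPPORTED_TOKENS = 1024


def can_use_mlapo(enable_mlapo: bool, has_prefill: bool,
                  num_decode_tokens: int) -> bool:
    # The fused MLA-preprocess (mlapo) path only handles decode-only batches
    # whose token count is within the operator's supported range.
    return (enable_mlapo and not has_prefill
            and num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS)


def can_use_bmm_transpose(x: torch.Tensor) -> bool:
    # batch_matmul_transpose only supports fp16/bf16 inputs whose first
    # dimension (token count) does not exceed its limit; larger batches
    # take the regular view/transpose path instead.
    return (x.dtype in (torch.float16, torch.bfloat16)
            and x.shape[0] <= BMM_TRANS_MAX_SUPPORTED_TOKENS)
```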
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with newly added and existing tests.
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
### Perf
Setup: 64K/3K, 1P1D, bs=32
- Before this PR: TPOT 29 ms, TTFT 47 s, TPS 606 tokens/s
- After this PR: TPOT 29 ms, TTFT 48 s, TPS 636 tokens/s
Signed-off-by: rjg-lyh <1318825571@qq.com>
@@ -36,7 +36,6 @@ class TestAscendSFABackend(TestBase):
class TestAscendSFAMetadata(TestBase):

    def test_ascend_sfa_metadata_default(self):
        has_prefill = True
        num_actual_tokens = 100
        slot_mapping = torch.randn(100, 4, 1024)
        seq_lens = torch.tensor([30, 50])
@@ -54,7 +53,6 @@ class TestAscendSFAMetadata(TestBase):
        attn_state = AscendAttentionState.ChunkedPrefill

        metadata = AscendSFAMetadata(
            has_prefill=has_prefill,
            num_actual_tokens=num_actual_tokens,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
@@ -68,7 +66,6 @@ class TestAscendSFAMetadata(TestBase):
            attn_state=attn_state,
        )

        self.assertEqual(metadata.has_prefill, has_prefill)
        self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
        self.assertIs(metadata.slot_mapping, slot_mapping)
        self.assertTrue(torch.equal(metadata.seq_lens, seq_lens))
@@ -57,6 +57,8 @@ else:
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
BUILD_METADATA_STEP_PREFILL = 0
BUILD_METADATA_STEP_DECODE = 1
# token count limits within the mlapo operator
MLAPO_MAX_SUPPORTED_TOKENS = 1024


class AscendMLABackend(AttentionBackend):
@@ -927,10 +929,9 @@ class AscendMLAImpl(MLAAttentionImpl):
        # On KV consumers (decode-only) MLAPO uses the transformed weights built above;
        # the original fused_qkv_a_proj/q_proj weights and quant params are no longer
        # referenced, so drop them to save memory.
        ascend_config = get_ascend_config()
        if self.vllm_config.kv_transfer_config is not None and \
                self.vllm_config.kv_transfer_config.is_kv_consumer and \
                ascend_config.recompute_scheduler_enable:
                self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
            self.fused_qkv_a_proj.weight = None
            self.fused_qkv_a_proj.deq_scale = None
            self.fused_qkv_a_proj.quant_bias = None
@@ -1508,7 +1509,9 @@ class AscendMLAImpl(MLAAttentionImpl):
                                     device=hidden_states.device)

        # MLA Preprocess
        if self.enable_mlapo and not has_prefill:
        if self.enable_mlapo and \
                not has_prefill and \
                attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
            hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                hidden_states.contiguous(), need_gather_q_kv)
            decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess_only_decode(
@@ -18,8 +18,9 @@ from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE, MLAPO_MAX_SUPPORTED_TOKENS
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                         maybe_save_kv_layer_to_connector,
                                         trans_rope_weight, transdata,
                                         wait_for_kv_layer_from_connector)
from vllm_ascend.distributed.utils import all_gather_async
@@ -47,6 +48,9 @@ else:
        AttentionBackend, AttentionCGSupport, MLAAttentionImpl)
# isort: on

# token count limits within bmm_transpose operator
BMM_TRANS_MAX_SUPPORTED_TOKENS = 1024


class AscendSFABackend(AttentionBackend):
@@ -99,7 +103,6 @@ class AscendSFAMetadata:
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    # |-- query_len ---|
    has_prefill: bool
    num_actual_tokens: int  # Number of tokens excluding padding.
    slot_mapping: torch.Tensor
    seq_lens: torch.Tensor
@@ -196,16 +199,11 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
        input_positions = common_attn_metadata.positions[:
                                                          num_input_tokens].long(
                                                          )
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        has_prefill = any(query_lens_cpu > self.decode_threshold)

        cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1]
        seq_lens = common_attn_metadata.seq_lens[:num_reqs]
        if has_prefill:
            cos, sin = get_cos_and_sin_mla(input_positions)
        else:
            cos, sin = get_cos_and_sin_mla(input_positions, True)

        cos, sin = get_cos_and_sin_mla(input_positions, True)

        sfa_cp_context = None
        if self.enable_sfa_cp:
@@ -285,7 +283,6 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
        )

        return self.metadata_cls(  # type: ignore
            has_prefill=has_prefill,
            num_input_tokens=common_attn_metadata.num_input_tokens,
            num_actual_tokens=num_actual_tokens,
            cum_query_lens=cum_query_lens,
@@ -368,7 +365,6 @@ class AscendSFAImpl(MLAAttentionImpl):
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tp_group().rank_in_group
        self.num_heads_per_rank = self.num_heads // self.tp_size
        self.q_b_proj = kwargs['q_b_proj']

        ascend_config = get_ascend_config()
@@ -469,21 +465,17 @@ class AscendSFAImpl(MLAAttentionImpl):
            # if mlapo, W_UK_T can't trans nz
            self.W_UK_T = maybe_trans_nz(self.W_UK_T)

    def _v_up_proj(self, x, has_prefill: bool):
        # TODO(zzzzwwjj): We should not judge by whether `has_prefill` or not.
        # The true criteria for judgment is tensorA's shape[0] <= 1024 (num_tokens <= 1024).
        # This is a bug in the previous code.
    def _v_up_proj(self, x):
        num_input_tokens, _, _ = x.shape
        if x.dtype in [torch.float16, torch.bfloat16] \
                and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \
                and not self.enable_sfa_cp \
                and not has_prefill:
            x = x.view(-1, self.num_heads, self.kv_lora_rank)
            b, _, _ = x.shape
            res = torch.empty((b, self.num_heads, self.v_head_dim),
                and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS:
            x = x.view(-1, self.local_num_heads, self.kv_lora_rank)
            res = torch.empty((num_input_tokens, self.local_num_heads, self.v_head_dim),
                              dtype=x.dtype,
                              device=x.device)
            torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
            x = res.reshape(-1, self.num_heads * self.v_head_dim)
            x = res.reshape(-1, self.local_num_heads * self.v_head_dim)
        else:
            # Convert from (B, N, L) to (N, B, L)
            x = x.view(-1, self.local_num_heads,
@@ -654,10 +646,9 @@ class AscendSFAImpl(MLAAttentionImpl):
        # On KV consumers (decode-only) MLAPO uses the transformed weights built above;
        # the original fused_qkv_a_proj/q_proj weights and quant params are no longer
        # referenced, so drop them to save memory.
        ascend_config = get_ascend_config()
        if self.vllm_config.kv_transfer_config is not None and \
                self.vllm_config.kv_transfer_config.is_kv_consumer and \
                ascend_config.recompute_scheduler_enable:
                self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
            self.fused_qkv_a_proj.weight = None
            self.fused_qkv_a_proj.deq_scale = None
            self.fused_qkv_a_proj.quant_bias = None
@@ -745,7 +736,6 @@ class AscendSFAImpl(MLAAttentionImpl):
            reach_layer_for_shard_weight_series(layer)
            return output.fill_(0)

        has_prefill = attn_metadata.has_prefill
        cos = attn_metadata.cos
        sin = attn_metadata.sin
        actual_seq_lengths_query = attn_metadata.cum_query_lens
@@ -753,17 +743,16 @@ class AscendSFAImpl(MLAAttentionImpl):
        if self.enable_sfa_cp:
            need_gather_q_kv = False
        # Inputs and outputs may be padded for CUDA graphs
        num_input_tokens = attn_metadata.num_input_tokens
        output_padded = output

        # TODO(zzzzwwjj): In sfa, prefill and decode have the same calculation formula,
        # so `has_prefill` here is not necessary.
        if self.enable_mlapo and not has_prefill:
        if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
            hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
                hidden_states=hidden_states,
                kv_cache=kv_cache,
                attn_metadata=attn_metadata,
                need_gather_q_kv=need_gather_q_kv,
                num_input_tokens=attn_metadata.num_input_tokens,
                num_input_tokens=num_input_tokens,
            )
            q, k = self.indexer_select_pre_process(
                x=hidden_states,
@@ -796,8 +785,7 @@ class AscendSFAImpl(MLAAttentionImpl):
            sin=sin,
            need_gather_q_kv=need_gather_q_kv)

        if has_prefill:
            wait_for_kv_layer_from_connector(layer_name)
        wait_for_kv_layer_from_connector(layer_name)

        slot_mapping = attn_metadata.slot_mapping
        if self.enable_sfa_cp:
@@ -875,12 +863,15 @@ class AscendSFAImpl(MLAAttentionImpl):
            sparse_mode=3,
        )

        attn_output = self._v_up_proj(attn_output, has_prefill)
        attn_output = self._v_up_proj(attn_output)
        maybe_npu_prefetch(inputs=self.o_proj.weight,
                           dependency=attn_output,
                           max_size=MAX_O_PROJ_PREFETCH_SIZE,
                           enabled=self.enable_prefetch)
        output[...] = self.o_proj(attn_output)[0]

        maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))

        return output_padded

    def indexer_select_pre_process(