[Bugfix] Fix the input constraints checks for the mlapo and bmm_transpose operators (#5764)
### What this PR does / why we need it?
This PR fixes the input-constraint checks for the mlapo and bmm_transpose operators. Both operators support at most 1024 input tokens, so their fast paths are now gated on the actual token count (via the new `MLAPO_MAX_SUPPORTED_TOKENS` and `BMM_TRANS_MAX_SUPPORTED_TOKENS` constants) rather than on the `has_prefill` flag, which was only a proxy for that constraint.
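For illustration, a minimal sketch of the corrected gate on the MLA side, assuming only the constants and condition shapes introduced by this PR (the surrounding class context is simplified away):

```python
# Sketch of the corrected gating, not the actual class code.
MLAPO_MAX_SUPPORTED_TOKENS = 1024      # mlapo accepts at most 1024 input tokens
BMM_TRANS_MAX_SUPPORTED_TOKENS = 1024  # bmm_transpose has the same limit


def can_use_mlapo(enable_mlapo: bool, has_prefill: bool,
                  num_decode_tokens: int) -> bool:
    # Before this PR the check was only `enable_mlapo and not has_prefill`,
    # so batches larger than the operator's supported range could still be
    # routed into mlapo.
    return (enable_mlapo
            and not has_prefill
            and num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS)
```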
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with newly added and existing tests.
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
### Perf
64K/3K, 1P1D, bs=32
Before this PR: TPOT 29 ms, TTFT 47 s, TPS 606 token/s
After this PR: TPOT 29 ms, TTFT 48 s, TPS 636 token/s
Signed-off-by: rjg-lyh <1318825571@qq.com>
```diff
@@ -36,7 +36,6 @@ class TestAscendSFABackend(TestBase):
 class TestAscendSFAMetadata(TestBase):

     def test_ascend_sfa_metadata_default(self):
-        has_prefill = True
         num_actual_tokens = 100
         slot_mapping = torch.randn(100, 4, 1024)
         seq_lens = torch.tensor([30, 50])
@@ -54,7 +53,6 @@ class TestAscendSFAMetadata(TestBase):
         attn_state = AscendAttentionState.ChunkedPrefill

         metadata = AscendSFAMetadata(
-            has_prefill=has_prefill,
             num_actual_tokens=num_actual_tokens,
             slot_mapping=slot_mapping,
             seq_lens=seq_lens,
@@ -68,7 +66,6 @@ class TestAscendSFAMetadata(TestBase):
             attn_state=attn_state,
         )

-        self.assertEqual(metadata.has_prefill, has_prefill)
         self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
         self.assertIs(metadata.slot_mapping, slot_mapping)
         self.assertTrue(torch.equal(metadata.seq_lens, seq_lens))
```
```diff
@@ -57,6 +57,8 @@ else:
 MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
 BUILD_METADATA_STEP_PREFILL = 0
 BUILD_METADATA_STEP_DECODE = 1
+# token count limits within the mlapo operator
+MLAPO_MAX_SUPPORTED_TOKENS = 1024


 class AscendMLABackend(AttentionBackend):
@@ -927,10 +929,9 @@ class AscendMLAImpl(MLAAttentionImpl):
         # On KV consumers (decode-only) MLAPO uses the transformed weights built above;
         # the original fused_qkv_a_proj/q_proj weights and quant params are no longer
         # referenced, so drop them to save memory.
-        ascend_config = get_ascend_config()
         if self.vllm_config.kv_transfer_config is not None and \
             self.vllm_config.kv_transfer_config.is_kv_consumer and \
-            ascend_config.recompute_scheduler_enable:
+            self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
             self.fused_qkv_a_proj.weight = None
             self.fused_qkv_a_proj.deq_scale = None
             self.fused_qkv_a_proj.quant_bias = None
@@ -1508,7 +1509,9 @@ class AscendMLAImpl(MLAAttentionImpl):
                 device=hidden_states.device)

         # MLA Preprocess
-        if self.enable_mlapo and not has_prefill:
+        if self.enable_mlapo and \
+            not has_prefill and \
+            attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 hidden_states.contiguous(), need_gather_q_kv)
             decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess_only_decode(
```
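A side note on the weight-drop hunk above: dropping the original `fused_qkv_a_proj` weights is only safe when every batch can go through mlapo, which is why the condition now compares the scheduler's `max_num_batched_tokens` against the operator limit. A minimal sketch, with the config fields stubbed as plain variables:

```python
MLAPO_MAX_SUPPORTED_TOKENS = 1024

is_kv_consumer = True         # decode-side instance in a 1P1D deployment
max_num_batched_tokens = 512  # stand-in for scheduler_config.max_num_batched_tokens

# The original weights may be freed only if mlapo can serve every batch;
# otherwise the fallback path would still need them.
can_drop_original_weights = (is_kv_consumer
                             and max_num_batched_tokens
                             <= MLAPO_MAX_SUPPORTED_TOKENS)
```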
```diff
@@ -18,8 +18,9 @@ from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
+from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE, MLAPO_MAX_SUPPORTED_TOKENS
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
+                                         maybe_save_kv_layer_to_connector,
                                          trans_rope_weight, transdata,
                                          wait_for_kv_layer_from_connector)
 from vllm_ascend.distributed.utils import all_gather_async
@@ -47,6 +48,9 @@ else:
     AttentionBackend, AttentionCGSupport, MLAAttentionImpl)
 # isort: on

+# token count limits within bmm_transpose operator
+BMM_TRANS_MAX_SUPPORTED_TOKENS = 1024
+

 class AscendSFABackend(AttentionBackend):

@@ -99,7 +103,6 @@ class AscendSFAMetadata:
     # |---------- context_len ----------|
     # |-------------------- seq_len ---------------------|
     # |-- query_len ---|
-    has_prefill: bool
     num_actual_tokens: int  # Number of tokens excluding padding.
     slot_mapping: torch.Tensor
     seq_lens: torch.Tensor
@@ -196,15 +199,10 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
         input_positions = common_attn_metadata.positions[:
                                                          num_input_tokens].long(
                                                          )
-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
-        query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
-        has_prefill = any(query_lens_cpu > self.decode_threshold)

         cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1]
         seq_lens = common_attn_metadata.seq_lens[:num_reqs]
-        if has_prefill:
-            cos, sin = get_cos_and_sin_mla(input_positions)
-        else:
-            cos, sin = get_cos_and_sin_mla(input_positions, True)
+        cos, sin = get_cos_and_sin_mla(input_positions, True)

         sfa_cp_context = None
@@ -285,7 +283,6 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
         )

         return self.metadata_cls(  # type: ignore
-            has_prefill=has_prefill,
             num_input_tokens=common_attn_metadata.num_input_tokens,
             num_actual_tokens=num_actual_tokens,
             cum_query_lens=cum_query_lens,
@@ -368,7 +365,6 @@ class AscendSFAImpl(MLAAttentionImpl):
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tp_group().rank_in_group
-        self.num_heads_per_rank = self.num_heads // self.tp_size
         self.q_b_proj = kwargs['q_b_proj']

         ascend_config = get_ascend_config()
@@ -469,21 +465,17 @@ class AscendSFAImpl(MLAAttentionImpl):
             # if mlapo, W_UK_T can't trans nz
             self.W_UK_T = maybe_trans_nz(self.W_UK_T)

-    def _v_up_proj(self, x, has_prefill: bool):
-        # TODO(zzzzwwjj): We should not judge by whether `has_prefill` or not.
-        # The true criteria for judgment is tensorA's shape[0] <= 1024 (num_tokens <= 1024).
-        # This is a bug in the previous code.
+    def _v_up_proj(self, x):
+        num_input_tokens, _, _ = x.shape
         if x.dtype in [torch.float16, torch.bfloat16] \
             and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \
-            and not self.enable_sfa_cp \
-            and not has_prefill:
-            x = x.view(-1, self.num_heads, self.kv_lora_rank)
-            b, _, _ = x.shape
-            res = torch.empty((b, self.num_heads, self.v_head_dim),
+            and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS:
+            x = x.view(-1, self.local_num_heads, self.kv_lora_rank)
+            res = torch.empty((num_input_tokens, self.local_num_heads, self.v_head_dim),
                               dtype=x.dtype,
                               device=x.device)
             torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
-            x = res.reshape(-1, self.num_heads * self.v_head_dim)
+            x = res.reshape(-1, self.local_num_heads * self.v_head_dim)
         else:
             # Convert from (B, N, L) to (N, B, L)
             x = x.view(-1, self.local_num_heads,
```
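The `_v_up_proj` change above applies the same idea to bmm_transpose: dispatch now keys off the token dimension of the input tensor rather than the prefill flag. A minimal sketch of the new criterion, with illustrative tensor sizes that are not taken from the PR:

```python
import torch

BMM_TRANS_MAX_SUPPORTED_TOKENS = 1024
local_num_heads, kv_lora_rank = 16, 512  # illustrative sizes only

x = torch.randn(512, local_num_heads, kv_lora_rank, dtype=torch.bfloat16)
num_input_tokens, _, _ = x.shape

# The custom operator is used only when the token count fits its limit;
# larger inputs fall back to the generic view/transpose/bmm path.
use_bmm_transpose = (x.dtype in (torch.float16, torch.bfloat16)
                     and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS)
print(use_bmm_transpose)  # True: 512 <= 1024
```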
```diff
@@ -654,10 +646,9 @@ class AscendSFAImpl(MLAAttentionImpl):
         # On KV consumers (decode-only) MLAPO uses the transformed weights built above;
         # the original fused_qkv_a_proj/q_proj weights and quant params are no longer
         # referenced, so drop them to save memory.
-        ascend_config = get_ascend_config()
         if self.vllm_config.kv_transfer_config is not None and \
             self.vllm_config.kv_transfer_config.is_kv_consumer and \
-            ascend_config.recompute_scheduler_enable:
+            self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
             self.fused_qkv_a_proj.weight = None
             self.fused_qkv_a_proj.deq_scale = None
             self.fused_qkv_a_proj.quant_bias = None
@@ -745,7 +736,6 @@ class AscendSFAImpl(MLAAttentionImpl):
             reach_layer_for_shard_weight_series(layer)
             return output.fill_(0)

-        has_prefill = attn_metadata.has_prefill
         cos = attn_metadata.cos
         sin = attn_metadata.sin
         actual_seq_lengths_query = attn_metadata.cum_query_lens
@@ -753,17 +743,16 @@ class AscendSFAImpl(MLAAttentionImpl):
         if self.enable_sfa_cp:
             need_gather_q_kv = False
         # Inputs and outputs may be padded for CUDA graphs
+        num_input_tokens = attn_metadata.num_input_tokens
         output_padded = output

-        # TODO(zzzzwwjj): In sfa, prefill and decode have the same calculation formula,
-        # so `has_prefill` here is not necessary.
-        if self.enable_mlapo and not has_prefill:
+        if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
             hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
                 hidden_states=hidden_states,
                 kv_cache=kv_cache,
                 attn_metadata=attn_metadata,
                 need_gather_q_kv=need_gather_q_kv,
-                num_input_tokens=attn_metadata.num_input_tokens,
+                num_input_tokens=num_input_tokens,
             )
         q, k = self.indexer_select_pre_process(
             x=hidden_states,
@@ -796,7 +785,6 @@ class AscendSFAImpl(MLAAttentionImpl):
                 sin=sin,
                 need_gather_q_kv=need_gather_q_kv)

-        if has_prefill:
-            wait_for_kv_layer_from_connector(layer_name)
+        wait_for_kv_layer_from_connector(layer_name)

         slot_mapping = attn_metadata.slot_mapping
@@ -875,12 +863,15 @@ class AscendSFAImpl(MLAAttentionImpl):
             sparse_mode=3,
         )

-        attn_output = self._v_up_proj(attn_output, has_prefill)
+        attn_output = self._v_up_proj(attn_output)
         maybe_npu_prefetch(inputs=self.o_proj.weight,
                            dependency=attn_output,
                            max_size=MAX_O_PROJ_PREFETCH_SIZE,
                            enabled=self.enable_prefetch)
         output[...] = self.o_proj(attn_output)[0]

+        maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
+
         return output_padded

     def indexer_select_pre_process(
```