[feature] support pcp + mtp in full graph (#4572)

1. Support pcp + mtp in full graph mode.
2. Fix pcp/dcp-related mtp bugs.
3. Support pcp + mtpx.

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Author: zhangsicheng5, committed by GitHub on 2025-12-22 16:13:39 +08:00
Commit: 78aa7f2693 (parent: 12d581605b)
10 changed files with 478 additions and 94 deletions


@@ -27,6 +27,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          split_decodes_and_prefills,
                                          wait_for_kv_layer_from_connector)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
+                                               get_mtp_graph_params,
                                                update_graph_params_workspaces)
 from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
 from vllm_ascend.ops.shared_weight_layer import (
@@ -92,6 +93,10 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded
         if num_actual_tokens_pcp_padded is None:
             num_actual_tokens_pcp_padded = num_actual_tokens
+        # In the dcp-only spec-decode graph-padding case,
+        # num_actual_tokens_pcp_padded may be less than num_actual_tokens.
+        num_actual_tokens_pcp_padded = max(num_actual_tokens_pcp_padded,
+                                           num_actual_tokens)
         num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp
         assert num_computed_tokens_of_pcp_dcp is not None
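
The clamp added above guards the dcp-only spec-decode graph-padding case. A minimal standalone sketch of the same logic, with made-up token counts rather than the vllm-ascend metadata objects:

```python
# Hypothetical values: with dcp-only spec-decode graph padding, the padded
# token count recorded in the long-seq metadata can be missing or smaller
# than the number of tokens actually scheduled, so the builder clamps it.
num_actual_tokens = 24
num_actual_tokens_pcp_padded = 16  # stale/short value from metadata

if num_actual_tokens_pcp_padded is None:
    num_actual_tokens_pcp_padded = num_actual_tokens
# Never slice buffers shorter than the real token count.
num_actual_tokens_pcp_padded = max(num_actual_tokens_pcp_padded,
                                   num_actual_tokens)
assert num_actual_tokens_pcp_padded == 24
```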
@@ -113,15 +118,6 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
                 common_attn_metadata.block_table_tensor[:graph_pad_size])
         else:
             block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
-            # NOTE: Currently, MTP-fullgraph is incompatibility pcp
-            if self.pcp_size > 1:
-                num_decodes_flatten = num_decodes * self.decode_threshold
-                block_table = common_attn_metadata.block_table_tensor[:
-                                                                      num_decodes_flatten
-                                                                      +
-                                                                      num_prefills]
-        # NOTE: Currently, MTP-fullgraph is incompatibility pcp
         slot_mapping = common_attn_metadata.slot_mapping[:
                                                          num_actual_tokens_pcp_padded]
         input_positions = common_attn_metadata.positions[:
@@ -144,6 +140,13 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
         num_computed_tokens_cpu = (seq_lens - query_lens)
+        # For pcp + spec decode, we flatten seq_lens and block_table
+        # to avoid irregular spec_attn_mask shape
+        num_decodes_flatten = query_lens[:num_decodes].sum().item()
+        block_table = common_attn_metadata.block_table_tensor[:
+                                                              num_decodes_flatten
+                                                              + num_prefills]
 
         prefill_metadata = None
         chunked_context_metadata = None
         if num_prefills > 0:
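
The flattening added above replaces the old per-request decode slicing removed in the previous hunk. A standalone sketch, with hypothetical shapes, of why num_decodes_flatten is the sum of the decode query lengths:

```python
import torch

# With pcp + spec decode, each decode request carries query_len rows
# (1 target token + k draft tokens), so the decode part of block_table is
# indexed per token rather than per request; this keeps spec_attn_mask
# rectangular. All shapes below are made up for illustration.
num_decodes, num_prefills = 2, 1
query_lens = torch.tensor([3, 3, 7])  # two decodes with k=2 MTP, one prefill

num_decodes_flatten = query_lens[:num_decodes].sum().item()  # 6
block_table_tensor = torch.arange(40).reshape(10, 4)  # [rows, blocks_per_row]
block_table = block_table_tensor[:num_decodes_flatten + num_prefills]

decode_block_table = block_table[:num_decodes_flatten]   # 6 per-token rows
prefill_block_table = block_table[num_decodes_flatten:]  # 1 per-request row
assert decode_block_table.shape == (6, 4)
assert prefill_block_table.shape == (1, 4)
```

The later hunks reuse this split: the prefill metadata takes block_table[num_decodes_flatten:] and the decode metadata takes block_table[:num_decodes_flatten].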
@@ -201,7 +204,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
                 dtype=torch.int32)
             local_context_lens_allranks = torch.tensor(
-                num_computed_tokens_of_pcp_dcp[reqs_start:num_reqs]
+                num_computed_tokens_of_pcp_dcp[num_decodes_flatten:]
             ).reshape(-1, self.dcp_size * self.pcp_size)
             # Note(qcs): The max local context lengths,
             # padded to `cp_local_block_size`.
@@ -280,9 +283,8 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
                 cos=cos,
                 pcp_metadata=pcp_metadata,
             )
-            if self.pcp_size > 1:
-                prefill_metadata.block_table = block_table[
-                    num_decodes_flatten:, ...]
+            prefill_metadata.block_table = \
+                block_table[num_decodes_flatten:, ...]
 
         decode_metadata = None
         if num_decodes > 0:
@@ -293,13 +295,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
             max_seq_lens = seq_lens[:num_decodes].max().item()
             seq_lens = seq_lens[:num_decodes]
             input_positions = input_positions[:num_decode_tokens]
-            if self.pcp_size > 1:
-                # For pcp + spec decode, we flatten seq_lens and block_table
-                # to avoid irregular spec_attn_mask shape
-                block_table = block_table[:num_decodes_flatten, ...]
-            else:
-                block_table = block_table[:num_decodes, ...]
-            # NOTE: Currently, MTP-fullgraph is incompatibility pcp
+            block_table = block_table[:num_decodes_flatten, ...]
             # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
             if graph_pad_size > num_decodes and \
                     self.speculative_config.disable_padded_drafter_batch:
@@ -308,8 +304,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
             # [bs, pcp_size, dcp_size]
             num_computed_tokens_of_cp_dcp_array = np.array(
-                num_computed_tokens_of_pcp_dcp)[:num_decodes *
-                                                self.decode_threshold]
+                num_computed_tokens_of_pcp_dcp)[:num_decodes_flatten]
             cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank,
                                                              self.dcp_rank]
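
A small numpy sketch (made-up sizes) of the slice-and-index pattern above on the [bs, pcp_size, dcp_size] array:

```python
import numpy as np

# num_computed_tokens_of_pcp_dcp holds, per flattened decode row, the
# number of computed tokens stored on every (pcp_rank, dcp_rank) pair;
# each rank picks out its own column. Sizes here are illustrative only.
pcp_size, dcp_size = 2, 2
pcp_rank, dcp_rank = 0, 1
num_decodes_flatten = 3

num_computed_tokens_of_pcp_dcp = np.arange(
    5 * pcp_size * dcp_size).reshape(5, pcp_size, dcp_size)
decode_rows = num_computed_tokens_of_pcp_dcp[:num_decodes_flatten]
cp_seq_len = decode_rows[:, pcp_rank, dcp_rank]  # this rank's local lengths
assert cp_seq_len.shape == (num_decodes_flatten,)
```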
@@ -1057,8 +1052,11 @@ class AscendMlaCPImpl(AscendMLAImpl):
             "return_lse": True,
             "calc_type": "calc_type_ring",
         }
-        graph_params = get_graph_params()
+        forward_context: ForwardContext = get_forward_context()
+        if forward_context.is_mtp_model:
+            graph_params = get_mtp_graph_params()
+        else:
+            graph_params = get_graph_params()
         if forward_context.capturing:
             stream = torch_npu.npu.current_stream()
             event = torch.npu.ExternalEvent()
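
The branch added above keeps the MTP draft model's captured graphs separate from the main model's. A dict-based stand-in sketch of the dispatch (the real get_graph_params/get_mtp_graph_params registries live in vllm_ascend.compilation.acl_graph and track more state than shown here):

```python
# Stand-in registries: the real ones hold per-graph workspaces, events and
# handles, and mixing the draft model's entries with the main model's
# would corrupt graph replay.
_MAIN_GRAPH_PARAMS: dict = {"owner": "main"}
_MTP_GRAPH_PARAMS: dict = {"owner": "mtp"}

def select_graph_params(is_mtp_model: bool) -> dict:
    # Mirrors the branch in the diff: MTP draft passes use their own params.
    return _MTP_GRAPH_PARAMS if is_mtp_model else _MAIN_GRAPH_PARAMS

assert select_graph_params(True) is _MTP_GRAPH_PARAMS
assert select_graph_params(False) is _MAIN_GRAPH_PARAMS
```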


@@ -67,6 +67,12 @@ class AscendPrefillContextParallelMetadata:
     pcp_prefill_mask: torch.Tensor = None
     # original query_lens before pcp split
     query_lens_pcp_full_cpu: torch.Tensor = None
+    # original max_query_len before pcp split
+    max_query_len_pcp_full: int = 0
 
 
 @dataclass
 class AscendCommonAttentionMetadata:
@@ -189,6 +195,8 @@ def split_decodes_and_prefills(
     """
     Assuming a reordered batch, finds the boundary between prefill and decode
     requests.
+    When pcp > 1, query_lens is split across pcp ranks, so we pass in the
+    original query_lens and max_query_len to distinguish prefills from decodes.
 
     Args:
         common_attn_metadata: AscendCommonAttentionMetadata object containing the
@@ -201,7 +209,13 @@ def split_decodes_and_prefills(
         num_decode_tokens: The number of tokens in the decode requests.
         num_prefill_tokens: The number of tokens in the prefill requests.
     """
-    max_query_len = common_attn_metadata.max_query_len
+    long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
+    query_lens_pcp_full = long_seq_metadata.query_lens_pcp_full_cpu \
+        if long_seq_metadata else None
+    max_query_len_pcp_full = long_seq_metadata.max_query_len_pcp_full \
+        if long_seq_metadata else 0
+    max_query_len = common_attn_metadata.max_query_len \
+        if max_query_len_pcp_full == 0 else max_query_len_pcp_full
     num_reqs = common_attn_metadata.num_reqs
     num_tokens = common_attn_metadata.num_actual_tokens
     query_start_loc = common_attn_metadata.query_start_loc_cpu
@@ -209,7 +223,8 @@ def split_decodes_and_prefills(
     if max_query_len <= decode_threshold:
         return num_reqs, 0, num_tokens, 0
-    query_lens = query_start_loc[1:] - query_start_loc[:-1]
+    query_lens = (query_start_loc[1:] - query_start_loc[:-1]) \
+        if query_lens_pcp_full is None else query_lens_pcp_full
     is_prefill = query_lens > decode_threshold
     if not torch.any(is_prefill):
        return num_reqs, 0, num_tokens, 0
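
Putting the three hunks together, a simplified standalone re-implementation (reduced signature, hypothetical inputs; not the vllm-ascend function itself) showing why the pre-split lengths matter:

```python
import torch

def split_decodes_and_prefills(query_start_loc: torch.Tensor,
                               max_query_len: int,
                               query_lens_pcp_full: torch.Tensor = None,
                               max_query_len_pcp_full: int = 0,
                               decode_threshold: int = 1):
    num_reqs = query_start_loc.numel() - 1
    num_tokens = int(query_start_loc[-1])
    # Prefer the pre-split max_query_len carried by the pcp metadata.
    max_q = max_query_len if max_query_len_pcp_full == 0 \
        else max_query_len_pcp_full
    if max_q <= decode_threshold:
        return num_reqs, 0, num_tokens, 0
    # Prefer the pre-split per-request lens for the same reason.
    query_lens = (query_start_loc[1:] - query_start_loc[:-1]) \
        if query_lens_pcp_full is None else query_lens_pcp_full
    is_prefill = query_lens > decode_threshold
    if not torch.any(is_prefill):
        return num_reqs, 0, num_tokens, 0
    first_prefill = int(is_prefill.int().argmax())
    num_decode_tokens = int(query_start_loc[first_prefill])
    return (first_prefill, num_reqs - first_prefill,
            num_decode_tokens, num_tokens - num_decode_tokens)

# A prefill of length 2 split across 2 pcp ranks has a rank-local
# query_len of 1, indistinguishable from a decode without the full lens:
qsl = torch.tensor([0, 1, 2, 3])  # three requests, local query_len 1 each
print(split_decodes_and_prefills(qsl, max_query_len=1))
# -> (3, 0, 3, 0): everything misread as decode
print(split_decodes_and_prefills(qsl, max_query_len=1,
                                 query_lens_pcp_full=torch.tensor([1, 1, 2]),
                                 max_query_len_pcp_full=2))
# -> (2, 1, 2, 1): the split prefill is correctly identified
```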