[feature] support pcp + mtp in full graph (#4572)

1. Support pcp + mtp in full graph
2. Fix pcp/dcp-related mtp bugs
3. Support pcp + mtpx

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
@@ -27,6 +27,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          split_decodes_and_prefills,
                                          wait_for_kv_layer_from_connector)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
+                                               get_mtp_graph_params,
                                                update_graph_params_workspaces)
 from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
 from vllm_ascend.ops.shared_weight_layer import (
@@ -92,6 +93,10 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded
         if num_actual_tokens_pcp_padded is None:
             num_actual_tokens_pcp_padded = num_actual_tokens
+        # In dcp only spec decode graph padding case,
+        # num_actual_tokens_pcp_padded may be less than num_actual_tokens
+        num_actual_tokens_pcp_padded = max(num_actual_tokens_pcp_padded,
+                                           num_actual_tokens)
         num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp
         assert num_computed_tokens_of_pcp_dcp is not None
 
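Aside: a minimal standalone sketch (made-up numbers; the dcp/graph-padding interaction is inferred from the comment in the hunk above) of what the new max() guard protects against:

    # Standalone sketch with hypothetical numbers.
    num_actual_tokens = 24             # tokens actually scheduled this step
    num_actual_tokens_pcp_padded = 16  # smaller padded count from the dcp-only
                                       # spec decode graph padding path

    # Clamp so downstream slices such as
    # slot_mapping[:num_actual_tokens_pcp_padded] never drop real tokens.
    num_actual_tokens_pcp_padded = max(num_actual_tokens_pcp_padded,
                                       num_actual_tokens)
    assert num_actual_tokens_pcp_padded >= num_actual_tokens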
@@ -113,15 +118,6 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
                 common_attn_metadata.block_table_tensor[:graph_pad_size])
         else:
             block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
-            # NOTE: Currently, MTP-fullgraph is incompatibility pcp
-            if self.pcp_size > 1:
-                num_decodes_flatten = num_decodes * self.decode_threshold
-                block_table = common_attn_metadata.block_table_tensor[:
-                                                                      num_decodes_flatten
-                                                                      +
-                                                                      num_prefills]
-
-        # NOTE: Currently, MTP-fullgraph is incompatibility pcp
         slot_mapping = common_attn_metadata.slot_mapping[:
                                                          num_actual_tokens_pcp_padded]
         input_positions = common_attn_metadata.positions[:
@@ -144,6 +140,13 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
         num_computed_tokens_cpu = (seq_lens - query_lens)
 
+        # For pcp + spec decode, we flatten seq_lens and block_table
+        # to avoid irregular spec_attn_mask shape
+        num_decodes_flatten = query_lens[:num_decodes].sum().item()
+        block_table = common_attn_metadata.block_table_tensor[:
+                                                              num_decodes_flatten
+                                                              + num_prefills]
+
         prefill_metadata = None
         chunked_context_metadata = None
         if num_prefills > 0:
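Aside: a toy run (all sizes hypothetical) of the flattening introduced above. With spec decode, each decode request occupies several query rows, so the decode slice of block_table is sized in tokens rather than requests:

    import torch

    num_decodes, num_prefills = 3, 2            # reordered batch: decodes first
    query_lens = torch.tensor([2, 2, 2, 7, 5])  # decode_threshold == 2

    # One block_table row per decode token keeps spec_attn_mask rectangular.
    num_decodes_flatten = query_lens[:num_decodes].sum().item()  # 6
    rows = num_decodes_flatten + num_prefills                    # 8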
@@ -201,7 +204,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
                     dtype=torch.int32)
 
                 local_context_lens_allranks = torch.tensor(
-                    num_computed_tokens_of_pcp_dcp[reqs_start:num_reqs]
+                    num_computed_tokens_of_pcp_dcp[num_decodes_flatten:]
                 ).reshape(-1, self.dcp_size * self.pcp_size)
                 # Note(qcs): The max local context lengths
                 # padded to `cp_local_block_size`.
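Aside: the reshape above, replayed with toy sizes (names mirror the diff; values are hypothetical). num_computed_tokens_of_pcp_dcp holds one [pcp_size, dcp_size] block of local context lengths per flattened row; taking the prefill tail after num_decodes_flatten and reshaping gives one row per prefill request:

    import torch

    pcp_size, dcp_size = 2, 2
    num_decodes_flatten = 1
    num_computed_tokens_of_pcp_dcp = [
        [[5, 4], [4, 4]],  # decode row
        [[3, 3], [2, 2]],  # prefill request 0
        [[6, 5], [5, 5]],  # prefill request 1
    ]
    local_context_lens_allranks = torch.tensor(
        num_computed_tokens_of_pcp_dcp[num_decodes_flatten:]
    ).reshape(-1, dcp_size * pcp_size)
    print(local_context_lens_allranks.shape)  # torch.Size([2, 4])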
@@ -280,9 +283,8 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
                 cos=cos,
                 pcp_metadata=pcp_metadata,
             )
-            if self.pcp_size > 1:
-                prefill_metadata.block_table = block_table[
-                    num_decodes_flatten:, ...]
+            prefill_metadata.block_table = \
+                block_table[num_decodes_flatten:, ...]
 
         decode_metadata = None
         if num_decodes > 0:
@@ -293,13 +295,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
             max_seq_lens = seq_lens[:num_decodes].max().item()
             seq_lens = seq_lens[:num_decodes]
             input_positions = input_positions[:num_decode_tokens]
-            if self.pcp_size > 1:
-                # For pcp + spec decode, we flatten seq_lens and block_table
-                # to avoid irregular spec_attn_mask shape
-                block_table = block_table[:num_decodes_flatten, ...]
-            else:
-                block_table = block_table[:num_decodes, ...]
-                # NOTE: Currently, MTP-fullgraph is incompatibility pcp
+            block_table = block_table[:num_decodes_flatten, ...]
             # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
             if graph_pad_size > num_decodes and \
                 self.speculative_config.disable_padded_drafter_batch:
@@ -308,8 +304,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
 
             # [bs, pcp_size, dcp_size]
             num_computed_tokens_of_cp_dcp_array = np.array(
-                num_computed_tokens_of_pcp_dcp)[:num_decodes *
-                                                self.decode_threshold]
+                num_computed_tokens_of_pcp_dcp)[:num_decodes_flatten]
 
             cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank,
                                                              self.dcp_rank]
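Aside: why the slice bound changes (hypothetical lengths; the divergence case is suggested by the disable_padded_drafter_batch note earlier in this file). With padded drafter batches every decode carries exactly decode_threshold tokens and the two bounds agree; with unpadded drafter batches they can differ:

    num_decodes, decode_threshold = 3, 2
    query_lens = [2, 1, 2]                 # one request accepted fewer draft tokens
    num_decodes_flatten = sum(query_lens)  # 5 rows actually present
    assert num_decodes_flatten != num_decodes * decode_threshold  # 5 != 6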
@@ -1057,8 +1052,11 @@ class AscendMlaCPImpl(AscendMLAImpl):
                 "return_lse": True,
                 "calc_type": "calc_type_ring",
             }
-            graph_params = get_graph_params()
             forward_context: ForwardContext = get_forward_context()
+            if forward_context.is_mtp_model:
+                graph_params = get_mtp_graph_params()
+            else:
+                graph_params = get_graph_params()
             if forward_context.capturing:
                 stream = torch_npu.npu.current_stream()
                 event = torch.npu.ExternalEvent()
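Aside: the selection pattern above, reduced to a runnable sketch. FakeForwardContext and the dict stores are hypothetical stand-ins; the point is that the MTP drafter and the target model read from separate graph-parameter stores, so graph captures for one cannot clobber the other's parameters:

    from dataclasses import dataclass

    @dataclass
    class FakeForwardContext:  # hypothetical stand-in for ForwardContext
        is_mtp_model: bool

    main_params, mtp_params = {}, {}  # stand-ins for the two param stores

    def select_graph_params(ctx: FakeForwardContext) -> dict:
        return mtp_params if ctx.is_mtp_model else main_params

    assert select_graph_params(FakeForwardContext(True)) is mtp_params
    assert select_graph_params(FakeForwardContext(False)) is main_params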
@@ -67,6 +67,12 @@ class AscendPrefillContextParallelMetadata:
 
     pcp_prefill_mask: torch.Tensor = None
 
+    # original query_lens before pcp split
+    query_lens_pcp_full_cpu: torch.Tensor = None
+
+    # original max_query_len before pcp split
+    max_query_len_pcp_full: int = 0
+
 
 @dataclass
 class AscendCommonAttentionMetadata:
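Aside: a toy illustration (numbers hypothetical) of why the pre-split lengths are stored. Each pcp rank holds only its slice of a prefill's query, so the rank-local length can fall at or below decode_threshold and masquerade as a decode:

    pcp_size = 4
    decode_threshold = 2                 # 1 target token + 1 MTP draft token
    query_len_full = 8                   # before the pcp split
    query_len_local = query_len_full // pcp_size  # 2, as seen by one rank

    assert query_len_local <= decode_threshold  # looks like a decode locally
    assert query_len_full > decode_threshold    # but is really a prefill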
@@ -189,6 +195,8 @@ def split_decodes_and_prefills(
     """
     Assuming a reordered batch, finds the boundary between prefill and decode
     requests.
+    When pcp > 1, query_lens is split across pcp ranks, so we pass in the
+    original query_lens and max_query_len to distinguish prefills and decodes.
 
     Args:
         common_attn_metadata: AscendCommonAttentionMetadata object containing the
@@ -201,7 +209,13 @@ def split_decodes_and_prefills(
         num_decode_tokens: The number of tokens in the decode requests.
         num_prefill_tokens: The number of tokens in the prefill requests.
     """
-    max_query_len = common_attn_metadata.max_query_len
+    long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
+    query_lens_pcp_full = long_seq_metadata.query_lens_pcp_full_cpu \
+        if long_seq_metadata else None
+    max_query_len_pcp_full = long_seq_metadata.max_query_len_pcp_full \
+        if long_seq_metadata else 0
+    max_query_len = common_attn_metadata.max_query_len \
+        if max_query_len_pcp_full == 0 else max_query_len_pcp_full
     num_reqs = common_attn_metadata.num_reqs
     num_tokens = common_attn_metadata.num_actual_tokens
     query_start_loc = common_attn_metadata.query_start_loc_cpu
@@ -209,7 +223,8 @@ def split_decodes_and_prefills(
     if max_query_len <= decode_threshold:
         return num_reqs, 0, num_tokens, 0
 
-    query_lens = query_start_loc[1:] - query_start_loc[:-1]
+    query_lens = (query_start_loc[1:] - query_start_loc[:-1]) \
+        if query_lens_pcp_full is None else query_lens_pcp_full
     is_prefill = query_lens > decode_threshold
    if not torch.any(is_prefill):
         return num_reqs, 0, num_tokens, 0
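Aside: a self-contained sketch of the corrected boundary search (simplified from split_decodes_and_prefills above; the metadata plumbing is elided and the sample tensors are hypothetical):

    from typing import Optional
    import torch

    def split_sketch(query_start_loc: torch.Tensor,
                     query_lens_pcp_full: Optional[torch.Tensor],
                     decode_threshold: int):
        # Prefer the pre-split lengths when pcp provides them; otherwise
        # derive rank-local lengths from query_start_loc.
        query_lens = (query_start_loc[1:] - query_start_loc[:-1]) \
            if query_lens_pcp_full is None else query_lens_pcp_full
        is_prefill = query_lens > decode_threshold
        if not torch.any(is_prefill):
            return len(query_lens), 0  # all decodes
        first_prefill = int(is_prefill.int().argmax().item())
        return first_prefill, len(query_lens) - first_prefill

    # pcp_size = 4 rank-local layout: two decodes (2 tokens each) then a
    # prefill whose local slice is also 2 tokens.
    query_start_loc = torch.tensor([0, 2, 4, 6])
    full_lens = torch.tensor([2, 2, 8])  # pre-split query lengths
    print(split_sketch(query_start_loc, None, 2))       # (3, 0): prefill missed
    print(split_sketch(query_start_loc, full_lens, 2))  # (2, 1): correct split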