[feature] support pcp + mtp in full graph (#4572)
1. support pcp + mtp in full graph
2. pcp/dcp related mtp bugfix
3. support pcp + mtpx
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
This commit is contained in:
@@ -32,6 +32,7 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
||||
update_mla_attn_dcp_pcp_params,
|
||||
update_mla_attn_params)
|
||||
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
|
||||
from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
|
||||
@@ -98,6 +99,7 @@ class MtpProposer(Proposer):
|
||||
self.pcp_size = self.runner.pcp_size
|
||||
self.dcp_size = self.runner.dcp_size
|
||||
self.pcp_rank = self.runner.pcp_rank
|
||||
self.dcp_rank = self.runner.dcp_rank
|
||||
|
||||
self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
|
||||
self.draft_indexer_metadata_builder: Optional[
|
||||
@@ -267,6 +269,13 @@ class MtpProposer(Proposer):
|
||||
attn_state=self.runner.attn_state,
|
||||
decode_token_per_req=self.runner.decode_token_per_req,
|
||||
)
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
# update long_seq related params and flatten block_table
|
||||
common_attn_metadata.prefill_context_parallel_metadata = \
|
||||
self.runner.long_seq_metadata
|
||||
common_attn_metadata.block_table_tensor = \
|
||||
self.runner.input_batch.block_table[0].get_device_tensor()[
|
||||
:num_reqs * self.decode_threshold]
|
||||
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
attn_metadata_mtp = builder.build_for_graph_capture(
|
||||
@@ -310,9 +319,15 @@ class MtpProposer(Proposer):
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
|
||||
not forward_context.capturing:
|
||||
if self.vllm_config.model_config.use_mla and not self.use_sparse:
|
||||
update_mla_attn_params(
|
||||
self.update_stream, forward_context, num_tokens,
|
||||
self.vllm_config.speculative_config)
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
update_mla_attn_dcp_pcp_params(
|
||||
self.update_stream, forward_context,
|
||||
num_tokens)
|
||||
else:
|
||||
update_mla_attn_params(
|
||||
self.update_stream, forward_context,
|
||||
num_tokens,
|
||||
self.vllm_config.speculative_config)
|
||||
if self.enable_shared_expert_dp:
|
||||
positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
positions, True)
|
||||
@@ -364,11 +379,11 @@ class MtpProposer(Proposer):
|
||||
valid_sampled_tokens_count)
|
||||
|
||||
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
if self.pcp_size > 1:
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
long_seq_metadata = self.runner.long_seq_metadata
|
||||
input_ids_pcp_full = self.runner.input_ids_pcp_full
|
||||
query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full
|
||||
query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full_cpu
|
||||
input_ids_pcp_full = self.runner.input_ids_pcp_full.gpu
|
||||
query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full.gpu
|
||||
query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full.cpu
|
||||
num_reqs = self.runner.input_batch.num_reqs
|
||||
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
|
||||
query_start_loc_pcp_full_cpu[:num_reqs]
|
||||
@@ -396,12 +411,11 @@ class MtpProposer(Proposer):
|
||||
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
||||
else:
|
||||
if self.pcp_size > 1:
|
||||
common_attn_metadata.query_start_loc_cpu = \
|
||||
common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] = \
|
||||
query_start_loc_pcp_full_cpu[:num_reqs + 1]
|
||||
common_attn_metadata.query_start_loc = \
|
||||
common_attn_metadata.query_start_loc[:num_reqs + 1] = \
|
||||
query_start_loc_pcp_full[:num_reqs + 1]
|
||||
if self.speculative_config.disable_padded_drafter_batch:
|
||||
# NOTE: Currently, MTP-fullgraph is incompatibility with pcp
|
||||
token_indices_to_sample = None
|
||||
common_attn_metadata, token_indices =\
|
||||
self._prepare_inputs(
|
||||
@@ -630,15 +644,18 @@ class MtpProposer(Proposer):
|
||||
self.input_ids[last_token_indices] = next_token_ids
|
||||
|
||||
# update pcp related params
|
||||
if self.pcp_size > 1:
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
assert long_seq_metadata is not None
|
||||
common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
|
||||
ori_last_token_indices = last_token_indices.clone()
|
||||
query_lens_d = self.runner.query_lens[:num_decode_reqs]
|
||||
if self.pcp_size > 1:
|
||||
# 1. preprocess decode/prefill input_ids & target_hidden_states
|
||||
# decode input_ids: keep unchanged
|
||||
# decode target_hidden_states: remove padding
|
||||
# prefill input_ids: add padding and pcp split
|
||||
# prefill target_hidden_states: pcp split
|
||||
num_tokens_d = num_decode_reqs * self.decode_threshold
|
||||
num_tokens_d = query_lens_d.sum().item()
|
||||
num_tokens_d_padded = num_tokens_d * self.pcp_size
|
||||
input_ids_d = self.input_ids[:num_tokens_d]
|
||||
input_ids_p = self.input_ids[num_tokens_d:num_tokens]
|
||||
@@ -646,12 +663,17 @@ class MtpProposer(Proposer):
|
||||
target_hidden_states[:num_tokens_d_padded]
|
||||
if num_tokens_d:
|
||||
# remove padding (from pcp all-gather) in decode part
|
||||
target_hidden_states_d = target_hidden_states_d_padded.reshape(
|
||||
[
|
||||
num_decode_reqs, self.decode_threshold * self.pcp_size,
|
||||
-1
|
||||
])[:, :self.decode_threshold, :].reshape(
|
||||
[num_tokens_d, -1])
|
||||
mask_start_loc = torch.cat([
|
||||
torch.tensor([0], dtype=torch.int32),
|
||||
torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]
|
||||
])
|
||||
mask_len = query_lens_d
|
||||
mask = []
|
||||
for req_id in range(num_decode_reqs):
|
||||
mask += list(
|
||||
range(mask_start_loc[req_id],
|
||||
mask_start_loc[req_id] + mask_len[req_id]))
|
||||
target_hidden_states_d = target_hidden_states_d_padded[mask]
|
||||
else:
|
||||
target_hidden_states_d = target_hidden_states_d_padded
|
||||
target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
|
||||
@@ -670,25 +692,26 @@ class MtpProposer(Proposer):
|
||||
torch.cat([input_ids_d, input_ids_p], dim=0))
|
||||
target_hidden_states = torch.cat(
|
||||
[target_hidden_states_d, target_hidden_states_p], dim=0)
|
||||
# 2. update attn_metadata params that may be influenced by pcp
|
||||
common_attn_metadata.num_actual_tokens = num_tokens
|
||||
common_attn_metadata.max_query_len = max(self.decode_threshold,
|
||||
max_query_len_p)
|
||||
common_attn_metadata.seq_lens[num_decode_reqs:] = seq_lens_p
|
||||
common_attn_metadata.seq_lens_cpu[num_decode_reqs:] = seq_lens_p
|
||||
query_start_loc_p = cu_num_tokens_p[1:] + \
|
||||
common_attn_metadata.query_start_loc[num_decode_reqs].item()
|
||||
common_attn_metadata.query_start_loc[num_decode_reqs + 1:] = \
|
||||
query_start_loc_p
|
||||
common_attn_metadata.query_start_loc_cpu[num_decode_reqs + 1:] = \
|
||||
query_start_loc_p
|
||||
# 3. update sample_indices according to main model
|
||||
# 2. update sample_indices according to main model
|
||||
if num_decode_reqs:
|
||||
last_token_indices[:num_decode_reqs] = \
|
||||
self.runner.logits_indices[last_token_indices[:num_decode_reqs]]
|
||||
if num_prefill_reqs:
|
||||
last_token_indices[-num_prefill_reqs:] = \
|
||||
self.runner.logits_indices[-num_prefill_reqs:]
|
||||
# 3. update attn_metadata params that may be influenced by pcp
|
||||
common_attn_metadata.num_actual_tokens = num_tokens
|
||||
common_attn_metadata.max_query_len = max(
|
||||
self.decode_threshold, max_query_len_p)
|
||||
common_attn_metadata.seq_lens[-num_prefill_reqs:] = seq_lens_p
|
||||
common_attn_metadata.seq_lens_cpu[
|
||||
-num_prefill_reqs:] = seq_lens_p
|
||||
query_start_loc_p = cu_num_tokens_p[1:] + \
|
||||
common_attn_metadata.query_start_loc[num_decode_reqs].item()
|
||||
common_attn_metadata.query_start_loc[-num_prefill_reqs:] = \
|
||||
query_start_loc_p
|
||||
common_attn_metadata.query_start_loc_cpu[-num_prefill_reqs:] = \
|
||||
query_start_loc_p
|
||||
|
||||
assert self.runner is not None
|
||||
|
||||
@@ -796,10 +819,15 @@ class MtpProposer(Proposer):
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
if self.vllm_config.model_config.use_mla and not self.use_sparse:
|
||||
update_mla_attn_params(
|
||||
self.update_stream, forward_context,
|
||||
num_input_tokens,
|
||||
self.vllm_config.speculative_config)
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
update_mla_attn_dcp_pcp_params(
|
||||
self.update_stream, forward_context,
|
||||
num_input_tokens)
|
||||
else:
|
||||
update_mla_attn_params(
|
||||
self.update_stream, forward_context,
|
||||
num_input_tokens,
|
||||
self.vllm_config.speculative_config)
|
||||
|
||||
if self.enable_shared_expert_dp:
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
@@ -814,7 +842,9 @@ class MtpProposer(Proposer):
|
||||
last_token_indices,
|
||||
(0, max_num_reqs_across_dp - num_indices))
|
||||
|
||||
if self.pcp_size > 1:
|
||||
if self.pcp_size > 1 and step == 0:
|
||||
# remove graph padding before all_gather
|
||||
hidden_states = hidden_states[:num_tokens]
|
||||
hidden_states = get_pcp_group().all_gather(hidden_states, 0)
|
||||
hidden_states = torch.index_select(
|
||||
hidden_states, 0, self.runner.
|
||||
@@ -855,6 +885,51 @@ class MtpProposer(Proposer):
|
||||
last_token_indices = self.arange[:batch_size]
|
||||
if getattr(attn_metadata_i, "num_decode_tokens", 0):
|
||||
attn_metadata_i.num_decode_tokens = batch_size
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
positions = target_positions[ori_last_token_indices]
|
||||
# For pcp/dcp, tokens are split across different cp ranks,
|
||||
# so we can not simply update slot_mapping by += 1.
|
||||
# Instead, we pre-allocate mtp slot_mapping in model_runner
|
||||
# (_generate_pcp_mtp_input), and use updated slot_indices
|
||||
# to get corresponding slot_mapping in each step.
|
||||
num_reject_tokens = torch.tensor(
|
||||
self.runner.cu_num_tokens_pcp_full,
|
||||
dtype=torch.int32).to(
|
||||
self.device) - ori_last_token_indices - 1
|
||||
num_accept_tokens = \
|
||||
query_lens_d.to(self.device) - num_reject_tokens
|
||||
ori_seq_len = attn_metadata_i.seq_lens
|
||||
mtp_slot_mapping = self.runner.mtp_slot_pad
|
||||
|
||||
# slot_mapping index base offset:
|
||||
# scheduled tokens + pre-allocated mtp tokens + accepted tokens
|
||||
slot_idx_base = (
|
||||
torch.cat([
|
||||
torch.tensor(
|
||||
[0], dtype=torch.int32, device=self.device),
|
||||
(torch.cumsum(query_lens_d, dim=0)[:-1] *
|
||||
self.pcp_size).to(self.device)
|
||||
]) +
|
||||
torch.arange(num_decode_reqs, device=self.device) *
|
||||
(self.num_speculative_tokens - 1) * self.pcp_size +
|
||||
(num_accept_tokens - 1) * self.pcp_size)
|
||||
slot_indices_list = []
|
||||
for req_id in range(num_decode_reqs):
|
||||
slot_indices_list.append(
|
||||
torch.arange(slot_idx_base[req_id],
|
||||
slot_idx_base[req_id] + self.pcp_size,
|
||||
device=self.device))
|
||||
slot_indices = torch.cat(slot_indices_list, dim=0)
|
||||
|
||||
# fold block_table (restore it to original size before flattened)
|
||||
block_indices = torch.cat([
|
||||
torch.tensor([0], dtype=torch.int32),
|
||||
torch.cumsum(query_lens_d, dim=0)[:-1]
|
||||
])
|
||||
attn_metadata_i.decode.block_table[:batch_size] = \
|
||||
attn_metadata_i.decode.block_table[block_indices]
|
||||
attn_metadata_i.decode.block_table = \
|
||||
attn_metadata_i.decode.block_table[:batch_size]
|
||||
|
||||
input_ids = draft_token_ids_list[-1].int()
|
||||
positions += 1
|
||||
@@ -901,13 +976,40 @@ class MtpProposer(Proposer):
|
||||
# Otherwise, the KV cache will be inadvertently updated with the
|
||||
# padding tokens.
|
||||
slot_mapping += 1
|
||||
if self.pcp_size > 1:
|
||||
exceeds_max_model_len = exceeds_max_model_len.repeat_interleave(
|
||||
slot_mapping.size(0) // exceeds_max_model_len.size(0))
|
||||
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.input_ids[:batch_size] = input_ids
|
||||
self.positions[:batch_size] = clamped_positions
|
||||
self.hidden_states[:hidden_states.shape[0]] = hidden_states
|
||||
attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
# update local seq_len and batch_seq_mask
|
||||
num_computed_tokens_of_pcp_dcp = self.runner._get_cp_local_seq_lens(
|
||||
ori_seq_len + step + 1,
|
||||
self.pcp_size,
|
||||
self.dcp_size,
|
||||
self.runner.parallel_config.cp_kv_cache_interleave_size,
|
||||
)
|
||||
cp_seq_len = \
|
||||
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank]
|
||||
batch_seq_mask = (cp_seq_len == 0)
|
||||
builder.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
|
||||
batch_seq_mask, non_blocking=True)
|
||||
batch_seq_mask = builder.batch_seq_mask_buf[:batch_seq_mask.
|
||||
shape[0]]
|
||||
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
|
||||
attn_metadata_i.decode.cp_seq_len = cp_seq_len
|
||||
attn_metadata_i.decode.batch_seq_mask = batch_seq_mask
|
||||
# update slot_mapping
|
||||
slot_indices += self.pcp_size
|
||||
slot_mapping = mtp_slot_mapping[slot_indices]
|
||||
attn_metadata_i.slot_mapping[:batch_size *
|
||||
self.pcp_size] = slot_mapping
|
||||
else:
|
||||
attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
|
||||
if self.speculative_config.disable_padded_drafter_batch:
|
||||
self.positions[batch_size:num_input_tokens] = 0
|
||||
self.input_ids[batch_size:num_input_tokens] = 0
|
||||
|
||||
Reference in New Issue
Block a user