[Refactor][EAGLE] 7/N Merged PCP and disable_padded interface (#6811)

### What this PR does / why we need it?
[Refactor][EAGLE] 7/N Merged PCP and disable_padded interface into
eagle_proposer.py

This pull request significantly refactors the speculative decoding
mechanism by merging Parallel Context Processing (PCP) and Multi-Token
Prediction (MTP) functionalities directly into the eagle_proposer.py.
The changes aim to enhance the efficiency and correctness of distributed
speculative decoding, particularly by enabling the Eagle feature to work
seamlessly with the disable_padded interface. This involves detailed
adjustments to attention metadata, input/output processing, and state
management to ensure proper operation in parallel environments.

1. The PCP and MTP features are migrated to the eagle_proposer.py
2. The Eagle and PCP features are integrated
3. Enable the eagle feature to use the disable_padded interface

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Tests and UT

- vLLM version: v0.15.0
- vLLM main:
83b47f67b1

---------

Signed-off-by: lilinsiman <lilinsiman@gmail.com>
This commit is contained in:
lilinsiman
2026-02-27 16:06:56 +08:00
committed by GitHub
parent e4458b2d2b
commit c13d90b766
6 changed files with 245 additions and 60 deletions

View File

@@ -118,6 +118,7 @@ class TestAscendAttentionCPImpl(TestBase):
attn_metadata = MagicMock()
attn_metadata.decode_meta = MagicMock()
attn_metadata.num_decodes_flatten = 5
attn_metadata.decode_meta.batch_seq_mask = torch.tensor(
[1, 0], dtype=torch.bool)
output = self.impl._forward_decode_pcp_dcp(query, attn_metadata)

View File

@@ -159,6 +159,7 @@ class AscendMetadata:
num_decode_tokens: int = 0
num_prefills: int = 0
num_decodes: int = 0
num_decodes_flatten: int = 0
# The sequence length per sequence. Sequence length means the computed
# tokens + new tokens (is None if it is a decoding).

View File

@@ -117,6 +117,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
block_table = common_attn_metadata.block_table_tensor
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
self.num_decodes_flatten = query_lens[:num_decodes].sum().item()
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
@@ -146,7 +147,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
pcp_size = get_pcp_group().world_size
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
local_context_lens_allranks = (
torch.tensor(num_computed_tokens_of_pcp_dcp)[num_decodes:num_reqs]
torch.tensor(num_computed_tokens_of_pcp_dcp)[self.num_decodes_flatten :]
.to(self.device)
.to(dtype=torch.int32)
)
@@ -214,23 +215,24 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
prefill_metadata = AscendMetadataForPrefill(
pcp_metadata=pcp_metadata,
chunked_context=chunked_context_metadata,
block_tables=block_table[num_decodes:],
block_tables=block_table[self.num_decodes_flatten :, ...],
actual_seq_lengths_q=torch.cumsum(query_lens, dim=0),
)
if num_decodes > 0:
num_computed_tokens_array = np.array(num_computed_tokens_of_pcp_dcp)
num_computed_tokens_array = num_computed_tokens_array[:num_decodes]
num_computed_tokens_array = num_computed_tokens_array[: self.num_decodes_flatten]
# TODO: numpy array mode of the shared memory is used to improve performance
decode_metadata = AscendMetadataForDecode(
num_computed_tokens_of_pcp_dcp=num_computed_tokens_array,
block_tables=block_table[:num_decodes],
block_tables=block_table[: self.num_decodes_flatten],
)
attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens,
num_decode_tokens=num_decode_tokens,
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
num_decodes_flatten=self.num_decodes_flatten,
block_tables=block_table,
query_start_loc=query_start_loc,
seq_lens=seq_lens,
@@ -550,7 +552,7 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
"actual_seq_lengths_kv": attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[
:, self.pcp_rank, self.dcp_rank
],
"actual_seq_lengths": attn_metadata.actual_seq_lengths_q[: attn_metadata.num_decodes],
"actual_seq_lengths": torch.arange(attn_metadata.num_decodes_flatten) + 1,
}
graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context()

View File

@@ -10,6 +10,7 @@ import torch.nn as nn
import torch.nn.functional as F
from vllm.config import CompilationMode, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config
from vllm.distributed.parallel_state import (
get_pcp_group,
get_pp_group,
get_tp_group,
get_world_group,
@@ -326,6 +327,12 @@ class EagleProposer(VllmEagleProposer):
decode_token_per_req=self.runner.decode_token_per_req,
max_seq_len=0,
)
if self.pcp_size * self.dcp_size > 1:
# update long_seq related params and flatten block_table
common_attn_metadata.prefill_context_parallel_metadata = self.runner.pcp_manager.long_seq_metadata
common_attn_metadata.block_table_tensor = self.runner.input_batch.block_table[0].get_device_tensor()[
: num_reqs * self.decode_threshold
]
builder = self.runner.attn_groups[0][0].get_metadata_builder()
# update the tensor's address for each step.
@@ -343,7 +350,9 @@ class EagleProposer(VllmEagleProposer):
model_positions = self._get_positions(num_tokens)
batch_size = num_tokens // (self.num_speculative_tokens + 1) if not is_profile else self.runner.max_num_reqs
batch_size = num_tokens // (self.num_speculative_tokens + 1) # if not is_profile else self.runner.max_num_reqs
if is_profile:
batch_size = min(batch_size, self.runner.max_num_reqs)
with set_ascend_forward_context(
multi_steps_attn_metadata[0] if multi_steps_attn_metadata else None,
@@ -371,6 +380,7 @@ class EagleProposer(VllmEagleProposer):
inputs_embeds=None,
multi_steps_attn_metadata=multi_steps_attn_metadata,
is_dummy=True,
num_tokens=num_tokens,
)
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not forward_context.capturing:
@@ -414,6 +424,62 @@ class EagleProposer(VllmEagleProposer):
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[last_token_indices] = next_token_ids
assert self.runner is not None
# update pcp related params
if self.pcp_size * self.dcp_size > 1:
assert long_seq_metadata is not None
common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
ori_last_token_indices = last_token_indices.clone()
query_lens_d = self.runner.query_lens[:num_decode_reqs]
if self.pcp_size > 1:
# 1. preprocess decode/prefill input_ids & target_hidden_states
# decode input_ids: keep unchanged
# decode target_hidden_states: remove padding
# prefill input_ids: add padding and pcp split
# prefill target_hidden_states: pcp split
num_tokens_d = query_lens_d.sum().item()
num_tokens_d_padded = num_tokens_d * self.pcp_size
input_ids_d = self.input_ids[:num_tokens_d]
input_ids_p = self.input_ids[num_tokens_d:num_tokens]
target_hidden_states_d_padded = target_hidden_states[:num_tokens_d_padded]
if num_tokens_d:
# remove padding (from pcp all-gather) in decode part
mask_start_loc = torch.cat(
[torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]]
)
mask_len = query_lens_d
mask = []
for req_id in range(num_decode_reqs):
mask += list(range(mask_start_loc[req_id], mask_start_loc[req_id] + mask_len[req_id]))
target_hidden_states_d = target_hidden_states_d_padded[mask]
else:
target_hidden_states_d = target_hidden_states_d_padded
target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
req_scheduled_tokens_p = {}
for i, req_id in enumerate(self.runner.input_batch.req_ids):
if i >= num_decode_reqs:
req_scheduled_tokens_p[req_id] = req_scheduled_tokens[req_id]
(num_tokens_p, input_ids_p, target_hidden_states_p, max_query_len_p, seq_lens_p, cu_num_tokens_p) = (
self._split_pcp_input(req_scheduled_tokens_p, input_ids_p, target_hidden_states_p)
)
num_tokens = num_tokens_d + num_tokens_p
target_positions = target_positions[:num_tokens]
self.input_ids[:num_tokens].copy_(torch.cat([input_ids_d, input_ids_p], dim=0))
target_hidden_states = torch.cat([target_hidden_states_d, target_hidden_states_p], dim=0)
# 2. update sample_indices according to main model
if num_decode_reqs:
last_token_indices[:num_decode_reqs] = self.runner.logits_indices[last_token_indices[:num_decode_reqs]]
if num_prefill_reqs:
last_token_indices[-num_prefill_reqs:] = self.runner.logits_indices[-num_prefill_reqs:]
# 3. update attn_metadata params that may be influenced by pcp
common_attn_metadata.num_actual_tokens = num_tokens
common_attn_metadata.max_query_len = max(self.decode_threshold, max_query_len_p)
common_attn_metadata.seq_lens[-num_prefill_reqs:] = seq_lens_p
common_attn_metadata.seq_lens_cpu[-num_prefill_reqs:] = seq_lens_p
query_start_loc_p = cu_num_tokens_p[1:] + common_attn_metadata.query_start_loc[num_decode_reqs].item()
common_attn_metadata.query_start_loc[-num_prefill_reqs:] = query_start_loc_p
common_attn_metadata.query_start_loc_cpu[-num_prefill_reqs:] = query_start_loc_p
if self.use_cuda_graph and num_tokens <= self.runner.cudagraph_batch_sizes[-1]:
num_input_tokens = self.runner.cudagraph_dispatcher._bs_to_padded_graph_size[num_tokens]
if not (
@@ -468,11 +534,7 @@ class EagleProposer(VllmEagleProposer):
# only tensor which will be used in current FIA.
# Strictly speaking, `query_start_loc`, `seq_lens` should also have
# their memory allocated separately for each step just like `slot_mapping`.
slot_mapping_lens = (
num_input_tokens
if num_input_tokens < common_attn_metadata.slot_mapping.shape[0]
else common_attn_metadata.slot_mapping.shape[0]
)
slot_mapping_lens = common_attn_metadata.slot_mapping.shape[0]
self.slot_mapping_group[0][:slot_mapping_lens].copy_(common_attn_metadata.slot_mapping[:slot_mapping_lens])
self.slot_mapping_group[0][slot_mapping_lens:].fill_(-1)
common_attn_metadata.slot_mapping = self.slot_mapping_group[0][:slot_mapping_lens]
@@ -491,21 +553,87 @@ class EagleProposer(VllmEagleProposer):
per_layer_attn_metadata[layer_name] = attn_metadata
multi_steps_attn_metadata = [per_layer_attn_metadata]
# Copy the old attn_metadata and update
for draft_step in range(1, self.num_speculative_tokens):
common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
draft_step,
attn_metadata,
common_attn_metadata,
batch_size,
num_input_tokens,
used_update_positions,
aclgraph_runtime_mode,
)
per_layer_attn_metadata = dict()
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
multi_steps_attn_metadata.append(per_layer_attn_metadata)
attn_metadata_i = per_layer_attn_metadata[self.attn_layer_names[0]]
if self.pcp_size * self.dcp_size > 1:
if self.num_speculative_tokens > 1 and not attn_metadata_i.num_prefills:
# For pcp/dcp, tokens are split across different cp ranks,
# so we can not simply update slot_mapping by += 1.
# Instead, we pre-allocate mtp slot_mapping in model_runner
# (_generate_pcp_mtp_input), and use updated slot_indices
# to get corresponding slot_mapping in each step.
num_reject_tokens = (
torch.tensor(self.runner.pcp_manager.cu_num_tokens_pcp_full, dtype=torch.int32).to(self.device)
- ori_last_token_indices
- 1
)
num_accept_tokens = query_lens_d.to(self.device) - num_reject_tokens
ori_seq_len = attn_metadata_i.seq_lens[:batch_size].clone()
mtp_slot_mapping = self.runner.pcp_manager.mtp_slot_pad
# slot_mapping index base offset:
# scheduled tokens + pre-allocated mtp tokens + accepted tokens
slot_idx_base = (
torch.cat(
[
torch.tensor([0], dtype=torch.int32, device=self.device),
(torch.cumsum(query_lens_d, dim=0)[:-1] * self.pcp_size).to(self.device),
]
)
+ torch.arange(num_decode_reqs, device=self.device)
* (self.num_speculative_tokens - 1)
* self.pcp_size
+ (num_accept_tokens - 1) * self.pcp_size
)
slot_indices_list = []
for req_id in range(num_decode_reqs):
slot_indices_list.append(
torch.arange(slot_idx_base[req_id], slot_idx_base[req_id] + self.pcp_size, device=self.device)
)
slot_indices = torch.cat(slot_indices_list, dim=0)
# fold block_table (restore it to original size before flattened)
block_indices = torch.cat(
[torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d, dim=0)[:-1]]
)
common_attn_metadata.block_table_tensor[:batch_size] = common_attn_metadata.block_table_tensor[
block_indices
]
common_attn_metadata.block_table_tensor = common_attn_metadata.block_table_tensor[:batch_size]
# Copy the old attn_metadata and update
for draft_step in range(1, self.num_speculative_tokens):
common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
draft_step,
attn_metadata,
common_attn_metadata,
batch_size,
num_input_tokens,
used_update_positions,
aclgraph_runtime_mode,
ori_seq_len,
slot_indices,
mtp_slot_mapping,
)
per_layer_attn_metadata = dict()
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
multi_steps_attn_metadata.append(per_layer_attn_metadata)
else:
# Copy the old attn_metadata and update
for draft_step in range(1, self.num_speculative_tokens):
common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
draft_step,
attn_metadata,
common_attn_metadata,
batch_size,
num_input_tokens,
used_update_positions,
aclgraph_runtime_mode,
)
per_layer_attn_metadata = dict()
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
multi_steps_attn_metadata.append(per_layer_attn_metadata)
last_token_indices_len = last_token_indices.shape[0]
self.last_token_indices[:last_token_indices_len].copy_(last_token_indices)
@@ -533,6 +661,8 @@ class EagleProposer(VllmEagleProposer):
target_positions=target_positions,
inputs_embeds=inputs_embeds,
multi_steps_attn_metadata=multi_steps_attn_metadata,
num_tokens=num_tokens,
is_prefill=attn_metadata_i.num_prefills,
)
forward_context = get_forward_context()
@@ -548,7 +678,9 @@ class EagleProposer(VllmEagleProposer):
target_positions,
inputs_embeds,
multi_steps_attn_metadata,
num_tokens,
is_dummy=False,
is_prefill=None,
) -> torch.Tensor:
# The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all
# speculative tokens' proposings. `model_input_ids`, `model_positions` and
@@ -575,6 +707,15 @@ class EagleProposer(VllmEagleProposer):
last_hidden_states, model_positions, hidden_states
)
if self.pcp_size > 1:
# remove graph padding before all_gather
hidden_states = hidden_states[:num_tokens]
hidden_states = get_pcp_group().all_gather(hidden_states, 0)
hidden_states = torch.index_select(
hidden_states, 0, self.runner.pcp_manager.pcp_allgather_restore_idx.gpu[: hidden_states.shape[0]]
)
last_hidden_states = hidden_states # TODO: check it
num_indices = last_token_indices.shape[0]
if lmhead_tp_enable() and not is_dummy:
max_num_reqs_across_dp = (
@@ -596,6 +737,13 @@ class EagleProposer(VllmEagleProposer):
# [batch_size, 1]
return draft_token_ids.view(-1, 1)
if self.pcp_size * self.dcp_size > 1 and is_prefill:
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list = []
for _ in range(self.num_speculative_tokens):
draft_token_ids_list.append(draft_token_ids)
return torch.stack(draft_token_ids_list, dim=1)
# Generate the remaining draft tokens.
draft_token_ids_tensor = torch.zeros(
(self.num_speculative_tokens, *draft_token_ids.shape), dtype=draft_token_ids.dtype, device=self.device
@@ -722,6 +870,9 @@ class EagleProposer(VllmEagleProposer):
input_batch_size,
used_update_positions,
aclgraph_runtime_mode,
ori_seq_len=None,
slot_indices=None,
mtp_slot_mapping=None,
):
assert draft_step > 0
common_attn_metadata = self.shallow_copy_metadata(old_common_metadata)
@@ -797,28 +948,42 @@ class EagleProposer(VllmEagleProposer):
attn_metadata_builder = self._get_attention_metadata_builder()
else:
attn_metadata_builder = self.attn_metadata_builder
block_size = attn_metadata_builder.kv_cache_spec.block_size
# Compute the slot mapping.
if self.uses_mrope:
block_numbers = clamped_positions[0] // block_size
if self.pcp_size * self.dcp_size > 1:
num_computed_tokens_of_pcp_dcp = self.runner.pcp_manager._get_cp_local_seq_lens(
ori_seq_len + draft_step,
self.pcp_size,
self.dcp_size,
self.runner.parallel_config.cp_kv_cache_interleave_size,
)
cp_seq_len = num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank]
# update slot_mapping
slot_indices += self.pcp_size
slot_mapping = mtp_slot_mapping[slot_indices]
common_attn_metadata.slot_mapping[: batch_size * self.pcp_size] = slot_mapping
else:
block_numbers = clamped_positions // block_size
block_ids = old_common_metadata.block_table_tensor.gather(dim=1, index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
if self.uses_mrope:
slot_mapping = block_ids * block_size + clamped_positions[0] % block_size
else:
slot_mapping = block_ids * block_size + clamped_positions % block_size
block_size = attn_metadata_builder.kv_cache_spec.block_size
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
self.slot_mapping_group[draft_step][: slot_mapping.shape[0]].copy_(slot_mapping.to(torch.int32))
self.slot_mapping_group[draft_step][slot_mapping.shape[0] :].fill_(PADDING_SLOT_ID)
# Set the address of the attn_metadata.slot_mapping to the self.slot_mapping_group[idx]
common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step][: slot_mapping.shape[0]]
# Compute the slot mapping.
if self.uses_mrope:
block_numbers = clamped_positions[0] // block_size
else:
block_numbers = clamped_positions // block_size
block_ids = old_common_metadata.block_table_tensor.gather(dim=1, index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
if self.uses_mrope:
slot_mapping = block_ids * block_size + clamped_positions[0] % block_size
else:
slot_mapping = block_ids * block_size + clamped_positions % block_size
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
self.slot_mapping_group[draft_step][: slot_mapping.shape[0]].copy_(slot_mapping.to(torch.int32))
self.slot_mapping_group[draft_step][slot_mapping.shape[0] :].fill_(PADDING_SLOT_ID)
# Set the address of the attn_metadata.slot_mapping to the self.slot_mapping_group[idx]
common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step][: slot_mapping.shape[0]]
# Rebuild attention metadata
attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore
@@ -826,6 +991,12 @@ class EagleProposer(VllmEagleProposer):
draft_index=draft_step,
)
if self.pcp_size * self.dcp_size > 1:
if self.vllm_config.model_config.use_mla:
attn_metadata.decode.cp_seq_len = cp_seq_len
else:
attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp
return common_attn_metadata, attn_metadata
def prepare_next_token_ids_padded(

View File

@@ -39,11 +39,7 @@ class MtpProposer(EagleProposer):
# Currently, both GLM and DS encounter issues when enabling the fullgraph mode and running on EagleProposer.
# Therefore, we temporarily bypass this problem by adding a conditional check for fullgraph.
# TODO: this conditional check should be removed after bug fixing.
if (
self.pcp_size * self.dcp_size == 1
and not self.speculative_config.disable_padded_drafter_batch
and not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
if not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs():
super().dummy_run(
num_tokens,
with_prefill,
@@ -175,11 +171,7 @@ class MtpProposer(EagleProposer):
# Currently, both GLM and DS encounter issues when enabling the fullgraph mode and running on EagleProposer.
# Therefore, we temporarily bypass this problem by adding a conditional check for fullgraph.
# TODO: this conditional check should be removed after bug fixing.
if (
self.pcp_size * self.dcp_size == 1
and not self.speculative_config.disable_padded_drafter_batch
and not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
if not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs():
draft_token_ids = super()._propose(
target_token_ids,
target_positions,

View File

@@ -561,7 +561,6 @@ class NPUModelRunner(GPUModelRunner):
dtype=np.int32,
)
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens)
self.attn_state = attn_state # type: ignore
# Determine if it's a splitfuse batch
with_prefill = attn_state not in [AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding]
@@ -800,7 +799,7 @@ class NPUModelRunner(GPUModelRunner):
attn_state = AscendAttentionState.SpecDecoding
# Speculative decoding.
elif np.all(num_valid_tokens == 1):
if self.speculative_config and self.speculative_config.method == "mtp":
if self.speculative_config:
attn_state = AscendAttentionState.SpecDecoding
else:
attn_state = AscendAttentionState.ChunkedPrefill
@@ -809,6 +808,14 @@ class NPUModelRunner(GPUModelRunner):
attn_state = AscendAttentionState.ChunkedPrefill
else:
attn_state = AscendAttentionState.PrefillCacheHit
# For the overlay of the PCP feature and the eagle3, attn_state needs to be recovered
# TODO: Resolved the conflict between the sunset of attn_state and the PCP that requires this interface.
if attn_state == AscendAttentionState.SpecDecoding and self.speculative_config.method != "mtp":
self.attn_state = AscendAttentionState.ChunkedPrefill # type: ignore
else:
self.attn_state = attn_state # type: ignore
return attn_state
def _calc_spec_decode_metadata(
@@ -977,6 +984,10 @@ class NPUModelRunner(GPUModelRunner):
target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
target_positions = self._get_positions(num_scheduled_tokens)
target_hidden_states = hidden_states
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1
)
else:
token_indices_to_sample = None
# input_ids can be None for multimodal models.
@@ -1014,6 +1025,8 @@ class NPUModelRunner(GPUModelRunner):
target_token_ids = input_ids_pcp_full[token_indices]
target_positions = positions
target_hidden_states = hidden_states
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_token_ids = self.input_ids.gpu[token_indices]
target_positions = self._get_positions(token_indices)
@@ -1260,13 +1273,18 @@ class NPUModelRunner(GPUModelRunner):
num_tokens_padded, input_ids, positions, intermediate_tensors, inputs_embeds, **model_kwargs
)
with record_function_or_nullcontext("post process"):
aux_hidden_states = None
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = hidden_states
if self.pcp_size > 1:
# NOTE we must `slice` hidden_states because pcp_allgather_restore_idx
# ignores the padding from CUDA Graph.
hidden_states = self.pcp_manager.get_restore_hidden_states(hidden_states)
aux_hidden_states = None
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = hidden_states
if aux_hidden_states is not None:
aux_hidden_states = [
self.pcp_manager.get_restore_hidden_states(aux_hidden_states_pcp)
for aux_hidden_states_pcp in aux_hidden_states
]
if not self.broadcast_pp_output:
# Common case.