[Refactor][EAGLE] 7/N Merged PCP and disable_padded interface (#6811)

### What this PR does / why we need it?
[Refactor][EAGLE] 7/N Merged PCP and disable_padded interface into
eagle_proposer.py

This pull request significantly refactors the speculative decoding
mechanism by merging Parallel Context Processing (PCP) and Multi-Token
Prediction (MTP) functionalities directly into the eagle_proposer.py.
The changes aim to enhance the efficiency and correctness of distributed
speculative decoding, particularly by enabling the Eagle feature to work
seamlessly with the disable_padded interface. This involves detailed
adjustments to attention metadata, input/output processing, and state
management to ensure proper operation in parallel environments.

1. The PCP and MTP features are migrated to the eagle_proposer.py
2. The Eagle and PCP features are integrated
3. Enable the eagle feature to use the disable_padded interface

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Tests and UT

- vLLM version: v0.15.0
- vLLM main:
83b47f67b1

---------

Signed-off-by: lilinsiman <lilinsiman@gmail.com>
This commit is contained in:
lilinsiman
2026-02-27 16:06:56 +08:00
committed by GitHub
parent e4458b2d2b
commit c13d90b766
6 changed files with 245 additions and 60 deletions

View File

@@ -118,6 +118,7 @@ class TestAscendAttentionCPImpl(TestBase):
attn_metadata = MagicMock() attn_metadata = MagicMock()
attn_metadata.decode_meta = MagicMock() attn_metadata.decode_meta = MagicMock()
attn_metadata.num_decodes_flatten = 5
attn_metadata.decode_meta.batch_seq_mask = torch.tensor( attn_metadata.decode_meta.batch_seq_mask = torch.tensor(
[1, 0], dtype=torch.bool) [1, 0], dtype=torch.bool)
output = self.impl._forward_decode_pcp_dcp(query, attn_metadata) output = self.impl._forward_decode_pcp_dcp(query, attn_metadata)

View File

@@ -159,6 +159,7 @@ class AscendMetadata:
num_decode_tokens: int = 0 num_decode_tokens: int = 0
num_prefills: int = 0 num_prefills: int = 0
num_decodes: int = 0 num_decodes: int = 0
num_decodes_flatten: int = 0
# The sequence length per sequence. Sequence length means the computed # The sequence length per sequence. Sequence length means the computed
# tokens + new tokens (is None if it is a decoding). # tokens + new tokens (is None if it is a decoding).

View File

@@ -117,6 +117,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
block_table = common_attn_metadata.block_table_tensor block_table = common_attn_metadata.block_table_tensor
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
self.num_decodes_flatten = query_lens[:num_decodes].sum().item()
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
@@ -146,7 +147,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
pcp_size = get_pcp_group().world_size pcp_size = get_pcp_group().world_size
if self.chunked_prefill_enabled and max_context_len_cpu > 0: if self.chunked_prefill_enabled and max_context_len_cpu > 0:
local_context_lens_allranks = ( local_context_lens_allranks = (
torch.tensor(num_computed_tokens_of_pcp_dcp)[num_decodes:num_reqs] torch.tensor(num_computed_tokens_of_pcp_dcp)[self.num_decodes_flatten :]
.to(self.device) .to(self.device)
.to(dtype=torch.int32) .to(dtype=torch.int32)
) )
@@ -214,23 +215,24 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
prefill_metadata = AscendMetadataForPrefill( prefill_metadata = AscendMetadataForPrefill(
pcp_metadata=pcp_metadata, pcp_metadata=pcp_metadata,
chunked_context=chunked_context_metadata, chunked_context=chunked_context_metadata,
block_tables=block_table[num_decodes:], block_tables=block_table[self.num_decodes_flatten :, ...],
actual_seq_lengths_q=torch.cumsum(query_lens, dim=0), actual_seq_lengths_q=torch.cumsum(query_lens, dim=0),
) )
if num_decodes > 0: if num_decodes > 0:
num_computed_tokens_array = np.array(num_computed_tokens_of_pcp_dcp) num_computed_tokens_array = np.array(num_computed_tokens_of_pcp_dcp)
num_computed_tokens_array = num_computed_tokens_array[:num_decodes] num_computed_tokens_array = num_computed_tokens_array[: self.num_decodes_flatten]
# TODO: numpy array mode of the shared memory is used to improve performance # TODO: numpy array mode of the shared memory is used to improve performance
decode_metadata = AscendMetadataForDecode( decode_metadata = AscendMetadataForDecode(
num_computed_tokens_of_pcp_dcp=num_computed_tokens_array, num_computed_tokens_of_pcp_dcp=num_computed_tokens_array,
block_tables=block_table[:num_decodes], block_tables=block_table[: self.num_decodes_flatten],
) )
attn_metadata = AscendMetadata( attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
num_decodes_flatten=self.num_decodes_flatten,
block_tables=block_table, block_tables=block_table,
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
seq_lens=seq_lens, seq_lens=seq_lens,
@@ -550,7 +552,7 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
"actual_seq_lengths_kv": attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[ "actual_seq_lengths_kv": attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[
:, self.pcp_rank, self.dcp_rank :, self.pcp_rank, self.dcp_rank
], ],
"actual_seq_lengths": attn_metadata.actual_seq_lengths_q[: attn_metadata.num_decodes], "actual_seq_lengths": torch.arange(attn_metadata.num_decodes_flatten) + 1,
} }
graph_params = get_graph_params() graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()

View File

@@ -10,6 +10,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from vllm.config import CompilationMode, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config from vllm.config import CompilationMode, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_pcp_group,
get_pp_group, get_pp_group,
get_tp_group, get_tp_group,
get_world_group, get_world_group,
@@ -326,6 +327,12 @@ class EagleProposer(VllmEagleProposer):
decode_token_per_req=self.runner.decode_token_per_req, decode_token_per_req=self.runner.decode_token_per_req,
max_seq_len=0, max_seq_len=0,
) )
if self.pcp_size * self.dcp_size > 1:
# update long_seq related params and flatten block_table
common_attn_metadata.prefill_context_parallel_metadata = self.runner.pcp_manager.long_seq_metadata
common_attn_metadata.block_table_tensor = self.runner.input_batch.block_table[0].get_device_tensor()[
: num_reqs * self.decode_threshold
]
builder = self.runner.attn_groups[0][0].get_metadata_builder() builder = self.runner.attn_groups[0][0].get_metadata_builder()
# update the tensor's address for each step. # update the tensor's address for each step.
@@ -343,7 +350,9 @@ class EagleProposer(VllmEagleProposer):
model_positions = self._get_positions(num_tokens) model_positions = self._get_positions(num_tokens)
batch_size = num_tokens // (self.num_speculative_tokens + 1) if not is_profile else self.runner.max_num_reqs batch_size = num_tokens // (self.num_speculative_tokens + 1) # if not is_profile else self.runner.max_num_reqs
if is_profile:
batch_size = min(batch_size, self.runner.max_num_reqs)
with set_ascend_forward_context( with set_ascend_forward_context(
multi_steps_attn_metadata[0] if multi_steps_attn_metadata else None, multi_steps_attn_metadata[0] if multi_steps_attn_metadata else None,
@@ -371,6 +380,7 @@ class EagleProposer(VllmEagleProposer):
inputs_embeds=None, inputs_embeds=None,
multi_steps_attn_metadata=multi_steps_attn_metadata, multi_steps_attn_metadata=multi_steps_attn_metadata,
is_dummy=True, is_dummy=True,
num_tokens=num_tokens,
) )
forward_context = get_forward_context() forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not forward_context.capturing: if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not forward_context.capturing:
@@ -414,6 +424,62 @@ class EagleProposer(VllmEagleProposer):
# Replace the last token with the next token. # Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[last_token_indices] = next_token_ids self.input_ids[last_token_indices] = next_token_ids
assert self.runner is not None
# update pcp related params
if self.pcp_size * self.dcp_size > 1:
assert long_seq_metadata is not None
common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
ori_last_token_indices = last_token_indices.clone()
query_lens_d = self.runner.query_lens[:num_decode_reqs]
if self.pcp_size > 1:
# 1. preprocess decode/prefill input_ids & target_hidden_states
# decode input_ids: keep unchanged
# decode target_hidden_states: remove padding
# prefill input_ids: add padding and pcp split
# prefill target_hidden_states: pcp split
num_tokens_d = query_lens_d.sum().item()
num_tokens_d_padded = num_tokens_d * self.pcp_size
input_ids_d = self.input_ids[:num_tokens_d]
input_ids_p = self.input_ids[num_tokens_d:num_tokens]
target_hidden_states_d_padded = target_hidden_states[:num_tokens_d_padded]
if num_tokens_d:
# remove padding (from pcp all-gather) in decode part
mask_start_loc = torch.cat(
[torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]]
)
mask_len = query_lens_d
mask = []
for req_id in range(num_decode_reqs):
mask += list(range(mask_start_loc[req_id], mask_start_loc[req_id] + mask_len[req_id]))
target_hidden_states_d = target_hidden_states_d_padded[mask]
else:
target_hidden_states_d = target_hidden_states_d_padded
target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
req_scheduled_tokens_p = {}
for i, req_id in enumerate(self.runner.input_batch.req_ids):
if i >= num_decode_reqs:
req_scheduled_tokens_p[req_id] = req_scheduled_tokens[req_id]
(num_tokens_p, input_ids_p, target_hidden_states_p, max_query_len_p, seq_lens_p, cu_num_tokens_p) = (
self._split_pcp_input(req_scheduled_tokens_p, input_ids_p, target_hidden_states_p)
)
num_tokens = num_tokens_d + num_tokens_p
target_positions = target_positions[:num_tokens]
self.input_ids[:num_tokens].copy_(torch.cat([input_ids_d, input_ids_p], dim=0))
target_hidden_states = torch.cat([target_hidden_states_d, target_hidden_states_p], dim=0)
# 2. update sample_indices according to main model
if num_decode_reqs:
last_token_indices[:num_decode_reqs] = self.runner.logits_indices[last_token_indices[:num_decode_reqs]]
if num_prefill_reqs:
last_token_indices[-num_prefill_reqs:] = self.runner.logits_indices[-num_prefill_reqs:]
# 3. update attn_metadata params that may be influenced by pcp
common_attn_metadata.num_actual_tokens = num_tokens
common_attn_metadata.max_query_len = max(self.decode_threshold, max_query_len_p)
common_attn_metadata.seq_lens[-num_prefill_reqs:] = seq_lens_p
common_attn_metadata.seq_lens_cpu[-num_prefill_reqs:] = seq_lens_p
query_start_loc_p = cu_num_tokens_p[1:] + common_attn_metadata.query_start_loc[num_decode_reqs].item()
common_attn_metadata.query_start_loc[-num_prefill_reqs:] = query_start_loc_p
common_attn_metadata.query_start_loc_cpu[-num_prefill_reqs:] = query_start_loc_p
if self.use_cuda_graph and num_tokens <= self.runner.cudagraph_batch_sizes[-1]: if self.use_cuda_graph and num_tokens <= self.runner.cudagraph_batch_sizes[-1]:
num_input_tokens = self.runner.cudagraph_dispatcher._bs_to_padded_graph_size[num_tokens] num_input_tokens = self.runner.cudagraph_dispatcher._bs_to_padded_graph_size[num_tokens]
if not ( if not (
@@ -468,11 +534,7 @@ class EagleProposer(VllmEagleProposer):
# only tensor which will be used in current FIA. # only tensor which will be used in current FIA.
# Strictly speaking, `query_start_loc`, `seq_lens` should also have # Strictly speaking, `query_start_loc`, `seq_lens` should also have
# their memory allocated separately for each step just like `slot_mapping`. # their memory allocated separately for each step just like `slot_mapping`.
slot_mapping_lens = ( slot_mapping_lens = common_attn_metadata.slot_mapping.shape[0]
num_input_tokens
if num_input_tokens < common_attn_metadata.slot_mapping.shape[0]
else common_attn_metadata.slot_mapping.shape[0]
)
self.slot_mapping_group[0][:slot_mapping_lens].copy_(common_attn_metadata.slot_mapping[:slot_mapping_lens]) self.slot_mapping_group[0][:slot_mapping_lens].copy_(common_attn_metadata.slot_mapping[:slot_mapping_lens])
self.slot_mapping_group[0][slot_mapping_lens:].fill_(-1) self.slot_mapping_group[0][slot_mapping_lens:].fill_(-1)
common_attn_metadata.slot_mapping = self.slot_mapping_group[0][:slot_mapping_lens] common_attn_metadata.slot_mapping = self.slot_mapping_group[0][:slot_mapping_lens]
@@ -491,21 +553,87 @@ class EagleProposer(VllmEagleProposer):
per_layer_attn_metadata[layer_name] = attn_metadata per_layer_attn_metadata[layer_name] = attn_metadata
multi_steps_attn_metadata = [per_layer_attn_metadata] multi_steps_attn_metadata = [per_layer_attn_metadata]
# Copy the old attn_metadata and update attn_metadata_i = per_layer_attn_metadata[self.attn_layer_names[0]]
for draft_step in range(1, self.num_speculative_tokens): if self.pcp_size * self.dcp_size > 1:
common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm( if self.num_speculative_tokens > 1 and not attn_metadata_i.num_prefills:
draft_step, # For pcp/dcp, tokens are split across different cp ranks,
attn_metadata, # so we can not simply update slot_mapping by += 1.
common_attn_metadata, # Instead, we pre-allocate mtp slot_mapping in model_runner
batch_size, # (_generate_pcp_mtp_input), and use updated slot_indices
num_input_tokens, # to get corresponding slot_mapping in each step.
used_update_positions, num_reject_tokens = (
aclgraph_runtime_mode, torch.tensor(self.runner.pcp_manager.cu_num_tokens_pcp_full, dtype=torch.int32).to(self.device)
) - ori_last_token_indices
per_layer_attn_metadata = dict() - 1
for layer_name in self.attn_layer_names: )
per_layer_attn_metadata[layer_name] = attn_metadata num_accept_tokens = query_lens_d.to(self.device) - num_reject_tokens
multi_steps_attn_metadata.append(per_layer_attn_metadata) ori_seq_len = attn_metadata_i.seq_lens[:batch_size].clone()
mtp_slot_mapping = self.runner.pcp_manager.mtp_slot_pad
# slot_mapping index base offset:
# scheduled tokens + pre-allocated mtp tokens + accepted tokens
slot_idx_base = (
torch.cat(
[
torch.tensor([0], dtype=torch.int32, device=self.device),
(torch.cumsum(query_lens_d, dim=0)[:-1] * self.pcp_size).to(self.device),
]
)
+ torch.arange(num_decode_reqs, device=self.device)
* (self.num_speculative_tokens - 1)
* self.pcp_size
+ (num_accept_tokens - 1) * self.pcp_size
)
slot_indices_list = []
for req_id in range(num_decode_reqs):
slot_indices_list.append(
torch.arange(slot_idx_base[req_id], slot_idx_base[req_id] + self.pcp_size, device=self.device)
)
slot_indices = torch.cat(slot_indices_list, dim=0)
# fold block_table (restore it to original size before flattened)
block_indices = torch.cat(
[torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d, dim=0)[:-1]]
)
common_attn_metadata.block_table_tensor[:batch_size] = common_attn_metadata.block_table_tensor[
block_indices
]
common_attn_metadata.block_table_tensor = common_attn_metadata.block_table_tensor[:batch_size]
# Copy the old attn_metadata and update
for draft_step in range(1, self.num_speculative_tokens):
common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
draft_step,
attn_metadata,
common_attn_metadata,
batch_size,
num_input_tokens,
used_update_positions,
aclgraph_runtime_mode,
ori_seq_len,
slot_indices,
mtp_slot_mapping,
)
per_layer_attn_metadata = dict()
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
multi_steps_attn_metadata.append(per_layer_attn_metadata)
else:
# Copy the old attn_metadata and update
for draft_step in range(1, self.num_speculative_tokens):
common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
draft_step,
attn_metadata,
common_attn_metadata,
batch_size,
num_input_tokens,
used_update_positions,
aclgraph_runtime_mode,
)
per_layer_attn_metadata = dict()
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
multi_steps_attn_metadata.append(per_layer_attn_metadata)
last_token_indices_len = last_token_indices.shape[0] last_token_indices_len = last_token_indices.shape[0]
self.last_token_indices[:last_token_indices_len].copy_(last_token_indices) self.last_token_indices[:last_token_indices_len].copy_(last_token_indices)
@@ -533,6 +661,8 @@ class EagleProposer(VllmEagleProposer):
target_positions=target_positions, target_positions=target_positions,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
multi_steps_attn_metadata=multi_steps_attn_metadata, multi_steps_attn_metadata=multi_steps_attn_metadata,
num_tokens=num_tokens,
is_prefill=attn_metadata_i.num_prefills,
) )
forward_context = get_forward_context() forward_context = get_forward_context()
@@ -548,7 +678,9 @@ class EagleProposer(VllmEagleProposer):
target_positions, target_positions,
inputs_embeds, inputs_embeds,
multi_steps_attn_metadata, multi_steps_attn_metadata,
num_tokens,
is_dummy=False, is_dummy=False,
is_prefill=None,
) -> torch.Tensor: ) -> torch.Tensor:
# The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all
# speculative tokens' proposings. `model_input_ids`, `model_positions` and # speculative tokens' proposings. `model_input_ids`, `model_positions` and
@@ -575,6 +707,15 @@ class EagleProposer(VllmEagleProposer):
last_hidden_states, model_positions, hidden_states last_hidden_states, model_positions, hidden_states
) )
if self.pcp_size > 1:
# remove graph padding before all_gather
hidden_states = hidden_states[:num_tokens]
hidden_states = get_pcp_group().all_gather(hidden_states, 0)
hidden_states = torch.index_select(
hidden_states, 0, self.runner.pcp_manager.pcp_allgather_restore_idx.gpu[: hidden_states.shape[0]]
)
last_hidden_states = hidden_states # TODO: check it
num_indices = last_token_indices.shape[0] num_indices = last_token_indices.shape[0]
if lmhead_tp_enable() and not is_dummy: if lmhead_tp_enable() and not is_dummy:
max_num_reqs_across_dp = ( max_num_reqs_across_dp = (
@@ -596,6 +737,13 @@ class EagleProposer(VllmEagleProposer):
# [batch_size, 1] # [batch_size, 1]
return draft_token_ids.view(-1, 1) return draft_token_ids.view(-1, 1)
if self.pcp_size * self.dcp_size > 1 and is_prefill:
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list = []
for _ in range(self.num_speculative_tokens):
draft_token_ids_list.append(draft_token_ids)
return torch.stack(draft_token_ids_list, dim=1)
# Generate the remaining draft tokens. # Generate the remaining draft tokens.
draft_token_ids_tensor = torch.zeros( draft_token_ids_tensor = torch.zeros(
(self.num_speculative_tokens, *draft_token_ids.shape), dtype=draft_token_ids.dtype, device=self.device (self.num_speculative_tokens, *draft_token_ids.shape), dtype=draft_token_ids.dtype, device=self.device
@@ -722,6 +870,9 @@ class EagleProposer(VllmEagleProposer):
input_batch_size, input_batch_size,
used_update_positions, used_update_positions,
aclgraph_runtime_mode, aclgraph_runtime_mode,
ori_seq_len=None,
slot_indices=None,
mtp_slot_mapping=None,
): ):
assert draft_step > 0 assert draft_step > 0
common_attn_metadata = self.shallow_copy_metadata(old_common_metadata) common_attn_metadata = self.shallow_copy_metadata(old_common_metadata)
@@ -797,28 +948,42 @@ class EagleProposer(VllmEagleProposer):
attn_metadata_builder = self._get_attention_metadata_builder() attn_metadata_builder = self._get_attention_metadata_builder()
else: else:
attn_metadata_builder = self.attn_metadata_builder attn_metadata_builder = self.attn_metadata_builder
block_size = attn_metadata_builder.kv_cache_spec.block_size
# Compute the slot mapping. if self.pcp_size * self.dcp_size > 1:
if self.uses_mrope: num_computed_tokens_of_pcp_dcp = self.runner.pcp_manager._get_cp_local_seq_lens(
block_numbers = clamped_positions[0] // block_size ori_seq_len + draft_step,
self.pcp_size,
self.dcp_size,
self.runner.parallel_config.cp_kv_cache_interleave_size,
)
cp_seq_len = num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank]
# update slot_mapping
slot_indices += self.pcp_size
slot_mapping = mtp_slot_mapping[slot_indices]
common_attn_metadata.slot_mapping[: batch_size * self.pcp_size] = slot_mapping
else: else:
block_numbers = clamped_positions // block_size block_size = attn_metadata_builder.kv_cache_spec.block_size
block_ids = old_common_metadata.block_table_tensor.gather(dim=1, index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
if self.uses_mrope:
slot_mapping = block_ids * block_size + clamped_positions[0] % block_size
else:
slot_mapping = block_ids * block_size + clamped_positions % block_size
# Mask out the slot mappings that exceed the max model length. # Compute the slot mapping.
# Otherwise, the KV cache will be inadvertently updated with the if self.uses_mrope:
# padding tokens. block_numbers = clamped_positions[0] // block_size
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) else:
self.slot_mapping_group[draft_step][: slot_mapping.shape[0]].copy_(slot_mapping.to(torch.int32)) block_numbers = clamped_positions // block_size
self.slot_mapping_group[draft_step][slot_mapping.shape[0] :].fill_(PADDING_SLOT_ID) block_ids = old_common_metadata.block_table_tensor.gather(dim=1, index=block_numbers.view(-1, 1))
# Set the address of the attn_metadata.slot_mapping to the self.slot_mapping_group[idx] block_ids = block_ids.view(-1)
common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step][: slot_mapping.shape[0]] if self.uses_mrope:
slot_mapping = block_ids * block_size + clamped_positions[0] % block_size
else:
slot_mapping = block_ids * block_size + clamped_positions % block_size
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
self.slot_mapping_group[draft_step][: slot_mapping.shape[0]].copy_(slot_mapping.to(torch.int32))
self.slot_mapping_group[draft_step][slot_mapping.shape[0] :].fill_(PADDING_SLOT_ID)
# Set the address of the attn_metadata.slot_mapping to the self.slot_mapping_group[idx]
common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step][: slot_mapping.shape[0]]
# Rebuild attention metadata # Rebuild attention metadata
attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore
@@ -826,6 +991,12 @@ class EagleProposer(VllmEagleProposer):
draft_index=draft_step, draft_index=draft_step,
) )
if self.pcp_size * self.dcp_size > 1:
if self.vllm_config.model_config.use_mla:
attn_metadata.decode.cp_seq_len = cp_seq_len
else:
attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp
return common_attn_metadata, attn_metadata return common_attn_metadata, attn_metadata
def prepare_next_token_ids_padded( def prepare_next_token_ids_padded(

View File

@@ -39,11 +39,7 @@ class MtpProposer(EagleProposer):
# Currently, both GLM and DS encounter issues when enabling the fullgraph mode and running on EagleProposer. # Currently, both GLM and DS encounter issues when enabling the fullgraph mode and running on EagleProposer.
# Therefore, we temporarily bypass this problem by adding a conditional check for fullgraph. # Therefore, we temporarily bypass this problem by adding a conditional check for fullgraph.
# TODO: this conditional check should be removed after bug fixing. # TODO: this conditional check should be removed after bug fixing.
if ( if not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.pcp_size * self.dcp_size == 1
and not self.speculative_config.disable_padded_drafter_batch
and not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
super().dummy_run( super().dummy_run(
num_tokens, num_tokens,
with_prefill, with_prefill,
@@ -175,11 +171,7 @@ class MtpProposer(EagleProposer):
# Currently, both GLM and DS encounter issues when enabling the fullgraph mode and running on EagleProposer. # Currently, both GLM and DS encounter issues when enabling the fullgraph mode and running on EagleProposer.
# Therefore, we temporarily bypass this problem by adding a conditional check for fullgraph. # Therefore, we temporarily bypass this problem by adding a conditional check for fullgraph.
# TODO: this conditional check should be removed after bug fixing. # TODO: this conditional check should be removed after bug fixing.
if ( if not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.pcp_size * self.dcp_size == 1
and not self.speculative_config.disable_padded_drafter_batch
and not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
draft_token_ids = super()._propose( draft_token_ids = super()._propose(
target_token_ids, target_token_ids,
target_positions, target_positions,

View File

@@ -561,7 +561,6 @@ class NPUModelRunner(GPUModelRunner):
dtype=np.int32, dtype=np.int32,
) )
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens) attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens)
self.attn_state = attn_state # type: ignore
# Determine if it's a splitfuse batch # Determine if it's a splitfuse batch
with_prefill = attn_state not in [AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding] with_prefill = attn_state not in [AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding]
@@ -800,7 +799,7 @@ class NPUModelRunner(GPUModelRunner):
attn_state = AscendAttentionState.SpecDecoding attn_state = AscendAttentionState.SpecDecoding
# Speculative decoding. # Speculative decoding.
elif np.all(num_valid_tokens == 1): elif np.all(num_valid_tokens == 1):
if self.speculative_config and self.speculative_config.method == "mtp": if self.speculative_config:
attn_state = AscendAttentionState.SpecDecoding attn_state = AscendAttentionState.SpecDecoding
else: else:
attn_state = AscendAttentionState.ChunkedPrefill attn_state = AscendAttentionState.ChunkedPrefill
@@ -809,6 +808,14 @@ class NPUModelRunner(GPUModelRunner):
attn_state = AscendAttentionState.ChunkedPrefill attn_state = AscendAttentionState.ChunkedPrefill
else: else:
attn_state = AscendAttentionState.PrefillCacheHit attn_state = AscendAttentionState.PrefillCacheHit
# For the overlay of the PCP feature and the eagle3, attn_state needs to be recovered
# TODO: Resolved the conflict between the sunset of attn_state and the PCP that requires this interface.
if attn_state == AscendAttentionState.SpecDecoding and self.speculative_config.method != "mtp":
self.attn_state = AscendAttentionState.ChunkedPrefill # type: ignore
else:
self.attn_state = attn_state # type: ignore
return attn_state return attn_state
def _calc_spec_decode_metadata( def _calc_spec_decode_metadata(
@@ -977,6 +984,10 @@ class NPUModelRunner(GPUModelRunner):
target_token_ids = input_ids_pcp_full[:num_scheduled_tokens] target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
target_positions = self._get_positions(num_scheduled_tokens) target_positions = self._get_positions(num_scheduled_tokens)
target_hidden_states = hidden_states target_hidden_states = hidden_states
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1
)
else: else:
token_indices_to_sample = None token_indices_to_sample = None
# input_ids can be None for multimodal models. # input_ids can be None for multimodal models.
@@ -1014,6 +1025,8 @@ class NPUModelRunner(GPUModelRunner):
target_token_ids = input_ids_pcp_full[token_indices] target_token_ids = input_ids_pcp_full[token_indices]
target_positions = positions target_positions = positions
target_hidden_states = hidden_states target_hidden_states = hidden_states
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1)
else: else:
target_token_ids = self.input_ids.gpu[token_indices] target_token_ids = self.input_ids.gpu[token_indices]
target_positions = self._get_positions(token_indices) target_positions = self._get_positions(token_indices)
@@ -1260,13 +1273,18 @@ class NPUModelRunner(GPUModelRunner):
num_tokens_padded, input_ids, positions, intermediate_tensors, inputs_embeds, **model_kwargs num_tokens_padded, input_ids, positions, intermediate_tensors, inputs_embeds, **model_kwargs
) )
with record_function_or_nullcontext("post process"): with record_function_or_nullcontext("post process"):
aux_hidden_states = None
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = hidden_states
if self.pcp_size > 1: if self.pcp_size > 1:
# NOTE we must `slice` hidden_states because pcp_allgather_restore_idx # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx
# ignores the padding from CUDA Graph. # ignores the padding from CUDA Graph.
hidden_states = self.pcp_manager.get_restore_hidden_states(hidden_states) hidden_states = self.pcp_manager.get_restore_hidden_states(hidden_states)
aux_hidden_states = None if aux_hidden_states is not None:
if self.use_aux_hidden_state_outputs: aux_hidden_states = [
hidden_states, aux_hidden_states = hidden_states self.pcp_manager.get_restore_hidden_states(aux_hidden_states_pcp)
for aux_hidden_states_pcp in aux_hidden_states
]
if not self.broadcast_pp_output: if not self.broadcast_pp_output:
# Common case. # Common case.