[Refactor][EAGLE] 7/N Merged PCP and disable_padded interface (#6811)
### What this PR does / why we need it?
[Refactor][EAGLE] 7/N Merged PCP and disable_padded interface into
eagle_proposer.py
This pull request refactors the speculative decoding mechanism by merging the
Parallel Context Processing (PCP) and Multi-Token Prediction (MTP)
functionality directly into eagle_proposer.py. The changes improve the
efficiency and correctness of distributed speculative decoding, in particular
by enabling the Eagle feature to work seamlessly with the disable_padded
interface. This requires detailed adjustments to attention metadata,
input/output processing, and state management to ensure correct operation in
parallel environments.
1. Migrate the PCP and MTP features into eagle_proposer.py
2. Integrate the Eagle and PCP features
3. Enable the Eagle feature to use the disable_padded interface
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
CI tests and unit tests
- vLLM version: v0.15.0
- vLLM main: 83b47f67b1
---------
Signed-off-by: lilinsiman <lilinsiman@gmail.com>
@@ -118,6 +118,7 @@ class TestAscendAttentionCPImpl(TestBase):
         attn_metadata = MagicMock()
         attn_metadata.decode_meta = MagicMock()
+        attn_metadata.num_decodes_flatten = 5
         attn_metadata.decode_meta.batch_seq_mask = torch.tensor(
             [1, 0], dtype=torch.bool)
         output = self.impl._forward_decode_pcp_dcp(query, attn_metadata)
@@ -159,6 +159,7 @@ class AscendMetadata:
     num_decode_tokens: int = 0
     num_prefills: int = 0
     num_decodes: int = 0
+    num_decodes_flatten: int = 0

     # The sequence length per sequence. Sequence length means the computed
     # tokens + new tokens (is None if it is a decoding).
@@ -117,6 +117,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
         block_table = common_attn_metadata.block_table_tensor
         query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        self.num_decodes_flatten = query_lens[:num_decodes].sum().item()
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]

         long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
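To make the new `num_decodes_flatten` concrete: with speculative decoding each decode request carries several tokens, so request-level slices like `[:num_decodes]` no longer line up with token-level tensors. A minimal sketch with made-up query lengths:

```python
import torch

# Hypothetical batch: 3 decode requests (2 tokens each under MTP),
# followed by 1 prefill request with 8 tokens.
query_start_loc_cpu = torch.tensor([0, 2, 4, 6, 14], dtype=torch.int32)
num_decodes = 3

query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]  # [2, 2, 2, 8]
num_decodes_flatten = query_lens[:num_decodes].sum().item()

# Token-level tensors (flattened block table, per-token metadata) must be
# sliced with the flattened token count, not the request count.
print(num_decodes_flatten)  # 6, whereas num_decodes is only 3
```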
@@ -146,7 +147,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
         pcp_size = get_pcp_group().world_size
         if self.chunked_prefill_enabled and max_context_len_cpu > 0:
             local_context_lens_allranks = (
-                torch.tensor(num_computed_tokens_of_pcp_dcp)[num_decodes:num_reqs]
+                torch.tensor(num_computed_tokens_of_pcp_dcp)[self.num_decodes_flatten :]
                 .to(self.device)
                 .to(dtype=torch.int32)
             )
@@ -214,23 +215,24 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
             prefill_metadata = AscendMetadataForPrefill(
                 pcp_metadata=pcp_metadata,
                 chunked_context=chunked_context_metadata,
-                block_tables=block_table[num_decodes:],
+                block_tables=block_table[self.num_decodes_flatten :, ...],
                 actual_seq_lengths_q=torch.cumsum(query_lens, dim=0),
             )

         if num_decodes > 0:
             num_computed_tokens_array = np.array(num_computed_tokens_of_pcp_dcp)
-            num_computed_tokens_array = num_computed_tokens_array[:num_decodes]
+            num_computed_tokens_array = num_computed_tokens_array[: self.num_decodes_flatten]
             # TODO: numpy array mode of the shared memory is used to improve performance
             decode_metadata = AscendMetadataForDecode(
                 num_computed_tokens_of_pcp_dcp=num_computed_tokens_array,
-                block_tables=block_table[:num_decodes],
+                block_tables=block_table[: self.num_decodes_flatten],
             )

         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
             num_decode_tokens=num_decode_tokens,
             num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
+            num_decodes_flatten=self.num_decodes_flatten,
             block_tables=block_table,
             query_start_loc=query_start_loc,
             seq_lens=seq_lens,
@@ -550,7 +552,7 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
                 "actual_seq_lengths_kv": attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[
                     :, self.pcp_rank, self.dcp_rank
                 ],
-                "actual_seq_lengths": attn_metadata.actual_seq_lengths_q[: attn_metadata.num_decodes],
+                "actual_seq_lengths": torch.arange(attn_metadata.num_decodes_flatten) + 1,
             }
             graph_params = get_graph_params()
             forward_context: ForwardContext = get_forward_context()
@@ -10,6 +10,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from vllm.config import CompilationMode, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config
 from vllm.distributed.parallel_state import (
+    get_pcp_group,
     get_pp_group,
     get_tp_group,
     get_world_group,
@@ -326,6 +327,12 @@ class EagleProposer(VllmEagleProposer):
             decode_token_per_req=self.runner.decode_token_per_req,
             max_seq_len=0,
         )
+        if self.pcp_size * self.dcp_size > 1:
+            # update long_seq related params and flatten block_table
+            common_attn_metadata.prefill_context_parallel_metadata = self.runner.pcp_manager.long_seq_metadata
+            common_attn_metadata.block_table_tensor = self.runner.input_batch.block_table[0].get_device_tensor()[
+                : num_reqs * self.decode_threshold
+            ]

         builder = self.runner.attn_groups[0][0].get_metadata_builder()
         # update the tensor's address for each step.
@@ -343,7 +350,9 @@ class EagleProposer(VllmEagleProposer):

         model_positions = self._get_positions(num_tokens)

-        batch_size = num_tokens // (self.num_speculative_tokens + 1) if not is_profile else self.runner.max_num_reqs
+        batch_size = num_tokens // (self.num_speculative_tokens + 1)
+        if is_profile:
+            batch_size = min(batch_size, self.runner.max_num_reqs)

         with set_ascend_forward_context(
             multi_steps_attn_metadata[0] if multi_steps_attn_metadata else None,
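A small sanity check of the reworked `batch_size` computation, assuming each request contributes `num_speculative_tokens + 1` tokens in the flattened spec-decode layout (all values here are made up):

```python
num_speculative_tokens = 2
max_num_reqs = 8  # hypothetical runner limit

def draft_batch_size(num_tokens: int, is_profile: bool) -> int:
    # Back-compute the request count from the flattened token count.
    batch_size = num_tokens // (num_speculative_tokens + 1)
    if is_profile:
        # During profiling, never exceed the runner's request capacity.
        batch_size = min(batch_size, max_num_reqs)
    return batch_size

assert draft_batch_size(12, is_profile=False) == 4   # 4 reqs x 3 tokens
assert draft_batch_size(60, is_profile=True) == 8    # clamped while profiling
```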
@@ -371,6 +380,7 @@ class EagleProposer(VllmEagleProposer):
             inputs_embeds=None,
             multi_steps_attn_metadata=multi_steps_attn_metadata,
             is_dummy=True,
+            num_tokens=num_tokens,
         )
         forward_context = get_forward_context()
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not forward_context.capturing:
@@ -414,6 +424,62 @@ class EagleProposer(VllmEagleProposer):
         # Replace the last token with the next token.
         # E.g., [a1, b1, b2, c1, c2, c3] -> [a2, b2, b3, c2, c3, c4]
         self.input_ids[last_token_indices] = next_token_ids
+
+        assert self.runner is not None
+        # update pcp related params
+        if self.pcp_size * self.dcp_size > 1:
+            assert long_seq_metadata is not None
+            common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
+            ori_last_token_indices = last_token_indices.clone()
+            query_lens_d = self.runner.query_lens[:num_decode_reqs]
+            if self.pcp_size > 1:
+                # 1. preprocess decode/prefill input_ids & target_hidden_states
+                # decode input_ids: keep unchanged
+                # decode target_hidden_states: remove padding
+                # prefill input_ids: add padding and pcp split
+                # prefill target_hidden_states: pcp split
+                num_tokens_d = query_lens_d.sum().item()
+                num_tokens_d_padded = num_tokens_d * self.pcp_size
+                input_ids_d = self.input_ids[:num_tokens_d]
+                input_ids_p = self.input_ids[num_tokens_d:num_tokens]
+                target_hidden_states_d_padded = target_hidden_states[:num_tokens_d_padded]
+                if num_tokens_d:
+                    # remove padding (from pcp all-gather) in decode part
+                    mask_start_loc = torch.cat(
+                        [torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]]
+                    )
+                    mask_len = query_lens_d
+                    mask = []
+                    for req_id in range(num_decode_reqs):
+                        mask += list(range(mask_start_loc[req_id], mask_start_loc[req_id] + mask_len[req_id]))
+                    target_hidden_states_d = target_hidden_states_d_padded[mask]
+                else:
+                    target_hidden_states_d = target_hidden_states_d_padded
+                target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
+                req_scheduled_tokens_p = {}
+                for i, req_id in enumerate(self.runner.input_batch.req_ids):
+                    if i >= num_decode_reqs:
+                        req_scheduled_tokens_p[req_id] = req_scheduled_tokens[req_id]
+                (num_tokens_p, input_ids_p, target_hidden_states_p, max_query_len_p, seq_lens_p, cu_num_tokens_p) = (
+                    self._split_pcp_input(req_scheduled_tokens_p, input_ids_p, target_hidden_states_p)
+                )
+                num_tokens = num_tokens_d + num_tokens_p
+                target_positions = target_positions[:num_tokens]
+                self.input_ids[:num_tokens].copy_(torch.cat([input_ids_d, input_ids_p], dim=0))
+                target_hidden_states = torch.cat([target_hidden_states_d, target_hidden_states_p], dim=0)
+                # 2. update sample_indices according to main model
+                if num_decode_reqs:
+                    last_token_indices[:num_decode_reqs] = self.runner.logits_indices[last_token_indices[:num_decode_reqs]]
+                if num_prefill_reqs:
+                    last_token_indices[-num_prefill_reqs:] = self.runner.logits_indices[-num_prefill_reqs:]
+                # 3. update attn_metadata params that may be influenced by pcp
+                common_attn_metadata.num_actual_tokens = num_tokens
+                common_attn_metadata.max_query_len = max(self.decode_threshold, max_query_len_p)
+                common_attn_metadata.seq_lens[-num_prefill_reqs:] = seq_lens_p
+                common_attn_metadata.seq_lens_cpu[-num_prefill_reqs:] = seq_lens_p
+                query_start_loc_p = cu_num_tokens_p[1:] + common_attn_metadata.query_start_loc[num_decode_reqs].item()
+                common_attn_metadata.query_start_loc[-num_prefill_reqs:] = query_start_loc_p
+                common_attn_metadata.query_start_loc_cpu[-num_prefill_reqs:] = query_start_loc_p
         if self.use_cuda_graph and num_tokens <= self.runner.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.runner.cudagraph_dispatcher._bs_to_padded_graph_size[num_tokens]
             if not (
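The decode de-padding step above can be hard to visualize, so here is a single-process toy version, assuming `pcp_size=2` and two decode requests with query lengths `[2, 3]` (all values invented): after the pcp all-gather, each request's block is padded to `query_len * pcp_size` rows, and only the first `query_len` rows of each block are real tokens.

```python
import torch

pcp_size = 2
query_lens_d = torch.tensor([2, 3], dtype=torch.int32)

# 10 gathered rows: request 0 owns rows 0-3 (2 real), request 1 rows 4-9 (3 real).
hidden = torch.arange(int(query_lens_d.sum()) * pcp_size).float().unsqueeze(-1)

mask_start_loc = torch.cat(
    [torch.tensor([0], dtype=torch.int32),
     torch.cumsum(query_lens_d * pcp_size, dim=0)[:-1]]
)
mask = []
for req_id in range(len(query_lens_d)):
    start = int(mask_start_loc[req_id])
    mask += list(range(start, start + int(query_lens_d[req_id])))

unpadded = hidden[mask]
print(unpadded.squeeze(-1))  # tensor([0., 1., 4., 5., 6.]) - padding rows dropped
```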
@@ -468,11 +534,7 @@ class EagleProposer(VllmEagleProposer):
         # only tensor which will be used in current FIA.
         # Strictly speaking, `query_start_loc`, `seq_lens` should also have
         # their memory allocated separately for each step just like `slot_mapping`.
-        slot_mapping_lens = (
-            num_input_tokens
-            if num_input_tokens < common_attn_metadata.slot_mapping.shape[0]
-            else common_attn_metadata.slot_mapping.shape[0]
-        )
+        slot_mapping_lens = common_attn_metadata.slot_mapping.shape[0]
         self.slot_mapping_group[0][:slot_mapping_lens].copy_(common_attn_metadata.slot_mapping[:slot_mapping_lens])
         self.slot_mapping_group[0][slot_mapping_lens:].fill_(-1)
         common_attn_metadata.slot_mapping = self.slot_mapping_group[0][:slot_mapping_lens]
@@ -491,21 +553,87 @@ class EagleProposer(VllmEagleProposer):
             per_layer_attn_metadata[layer_name] = attn_metadata
         multi_steps_attn_metadata = [per_layer_attn_metadata]

-        # Copy the old attn_metadata and update
-        for draft_step in range(1, self.num_speculative_tokens):
-            common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
-                draft_step,
-                attn_metadata,
-                common_attn_metadata,
-                batch_size,
-                num_input_tokens,
-                used_update_positions,
-                aclgraph_runtime_mode,
-            )
-            per_layer_attn_metadata = dict()
-            for layer_name in self.attn_layer_names:
-                per_layer_attn_metadata[layer_name] = attn_metadata
-            multi_steps_attn_metadata.append(per_layer_attn_metadata)
+        attn_metadata_i = per_layer_attn_metadata[self.attn_layer_names[0]]
+        if self.pcp_size * self.dcp_size > 1:
+            if self.num_speculative_tokens > 1 and not attn_metadata_i.num_prefills:
+                # For pcp/dcp, tokens are split across different cp ranks,
+                # so we can not simply update slot_mapping by += 1.
+                # Instead, we pre-allocate mtp slot_mapping in model_runner
+                # (_generate_pcp_mtp_input), and use updated slot_indices
+                # to get corresponding slot_mapping in each step.
+                num_reject_tokens = (
+                    torch.tensor(self.runner.pcp_manager.cu_num_tokens_pcp_full, dtype=torch.int32).to(self.device)
+                    - ori_last_token_indices
+                    - 1
+                )
+                num_accept_tokens = query_lens_d.to(self.device) - num_reject_tokens
+                ori_seq_len = attn_metadata_i.seq_lens[:batch_size].clone()
+                mtp_slot_mapping = self.runner.pcp_manager.mtp_slot_pad
+
+                # slot_mapping index base offset:
+                # scheduled tokens + pre-allocated mtp tokens + accepted tokens
+                slot_idx_base = (
+                    torch.cat(
+                        [
+                            torch.tensor([0], dtype=torch.int32, device=self.device),
+                            (torch.cumsum(query_lens_d, dim=0)[:-1] * self.pcp_size).to(self.device),
+                        ]
+                    )
+                    + torch.arange(num_decode_reqs, device=self.device)
+                    * (self.num_speculative_tokens - 1)
+                    * self.pcp_size
+                    + (num_accept_tokens - 1) * self.pcp_size
+                )
+                slot_indices_list = []
+                for req_id in range(num_decode_reqs):
+                    slot_indices_list.append(
+                        torch.arange(slot_idx_base[req_id], slot_idx_base[req_id] + self.pcp_size, device=self.device)
+                    )
+                slot_indices = torch.cat(slot_indices_list, dim=0)
+
+                # fold block_table (restore it to original size before flattened)
+                block_indices = torch.cat(
+                    [torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d, dim=0)[:-1]]
+                )
+                common_attn_metadata.block_table_tensor[:batch_size] = common_attn_metadata.block_table_tensor[
+                    block_indices
+                ]
+                common_attn_metadata.block_table_tensor = common_attn_metadata.block_table_tensor[:batch_size]
+
+            # Copy the old attn_metadata and update
+            for draft_step in range(1, self.num_speculative_tokens):
+                common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
+                    draft_step,
+                    attn_metadata,
+                    common_attn_metadata,
+                    batch_size,
+                    num_input_tokens,
+                    used_update_positions,
+                    aclgraph_runtime_mode,
+                    ori_seq_len,
+                    slot_indices,
+                    mtp_slot_mapping,
+                )
+                per_layer_attn_metadata = dict()
+                for layer_name in self.attn_layer_names:
+                    per_layer_attn_metadata[layer_name] = attn_metadata
+                multi_steps_attn_metadata.append(per_layer_attn_metadata)
+        else:
+            # Copy the old attn_metadata and update
+            for draft_step in range(1, self.num_speculative_tokens):
+                common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
+                    draft_step,
+                    attn_metadata,
+                    common_attn_metadata,
+                    batch_size,
+                    num_input_tokens,
+                    used_update_positions,
+                    aclgraph_runtime_mode,
+                )
+                per_layer_attn_metadata = dict()
+                for layer_name in self.attn_layer_names:
+                    per_layer_attn_metadata[layer_name] = attn_metadata
+                multi_steps_attn_metadata.append(per_layer_attn_metadata)

         last_token_indices_len = last_token_indices.shape[0]
         self.last_token_indices[:last_token_indices_len].copy_(last_token_indices)
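The `slot_idx_base` arithmetic packs three offsets into one index: the scheduled tokens of the preceding requests, their pre-allocated MTP tokens, and the tokens this request accepted, all scaled by `pcp_size` because every token occupies one slot per cp rank. A toy walk-through with invented numbers (`pcp_size=2`, two decode requests, `num_speculative_tokens=2`):

```python
import torch

pcp_size = 2
num_speculative_tokens = 2
num_decode_reqs = 2
query_lens_d = torch.tensor([2, 2], dtype=torch.int32)
num_accept_tokens = torch.tensor([1, 2], dtype=torch.int32)  # per request

slot_idx_base = (
    torch.cat([torch.tensor([0], dtype=torch.int32),
               torch.cumsum(query_lens_d, dim=0)[:-1] * pcp_size])
    + torch.arange(num_decode_reqs) * (num_speculative_tokens - 1) * pcp_size
    + (num_accept_tokens - 1) * pcp_size
)
print(slot_idx_base)  # tensor([0, 8])

# Each draft step then reads pcp_size consecutive pre-allocated slots per request.
slot_indices = torch.cat(
    [torch.arange(int(b), int(b) + pcp_size) for b in slot_idx_base]
)
print(slot_indices)  # tensor([0, 1, 8, 9])
```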
@@ -533,6 +661,8 @@ class EagleProposer(VllmEagleProposer):
             target_positions=target_positions,
             inputs_embeds=inputs_embeds,
             multi_steps_attn_metadata=multi_steps_attn_metadata,
+            num_tokens=num_tokens,
+            is_prefill=attn_metadata_i.num_prefills,
         )

         forward_context = get_forward_context()
@@ -548,7 +678,9 @@ class EagleProposer(VllmEagleProposer):
         target_positions,
         inputs_embeds,
         multi_steps_attn_metadata,
+        num_tokens,
         is_dummy=False,
+        is_prefill=None,
     ) -> torch.Tensor:
         # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all
         # speculative tokens' proposings. `model_input_ids`, `model_positions` and
@@ -575,6 +707,15 @@ class EagleProposer(VllmEagleProposer):
                 last_hidden_states, model_positions, hidden_states
             )

+            if self.pcp_size > 1:
+                # remove graph padding before all_gather
+                hidden_states = hidden_states[:num_tokens]
+                hidden_states = get_pcp_group().all_gather(hidden_states, 0)
+                hidden_states = torch.index_select(
+                    hidden_states, 0, self.runner.pcp_manager.pcp_allgather_restore_idx.gpu[: hidden_states.shape[0]]
+                )
+                last_hidden_states = hidden_states  # TODO: check it
+
         num_indices = last_token_indices.shape[0]
         if lmhead_tp_enable() and not is_dummy:
             max_num_reqs_across_dp = (
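Since the all_gather concatenates rank-local shards along dim 0, the gathered rows come out rank-major rather than in token order; `pcp_allgather_restore_idx` is the permutation that undoes this. A single-process sketch (the restore index is built by hand here; the real one is maintained by the pcp manager):

```python
import torch

# Original 6 tokens, sharded across pcp_size=2 ranks:
# rank 0 holds tokens [0, 1, 4]; rank 1 holds tokens [2, 3, 5].
tokens = torch.arange(6).float().unsqueeze(-1)
rank0, rank1 = tokens[[0, 1, 4]], tokens[[2, 3, 5]]

# all_gather along dim 0 yields rank 0's shard followed by rank 1's.
gathered = torch.cat([rank0, rank1], dim=0)  # token order: 0 1 4 2 3 5

# restore_idx[i] = row of original token i inside `gathered`.
restore_idx = torch.tensor([0, 1, 3, 4, 2, 5])
restored = torch.index_select(gathered, 0, restore_idx)
print(restored.squeeze(-1))  # tensor([0., 1., 2., 3., 4., 5.])
```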
@@ -596,6 +737,13 @@ class EagleProposer(VllmEagleProposer):
             # [batch_size, 1]
             return draft_token_ids.view(-1, 1)

+        if self.pcp_size * self.dcp_size > 1 and is_prefill:
+            draft_token_ids = logits.argmax(dim=-1)
+            draft_token_ids_list = []
+            for _ in range(self.num_speculative_tokens):
+                draft_token_ids_list.append(draft_token_ids)
+            return torch.stack(draft_token_ids_list, dim=1)
+
         # Generate the remaining draft tokens.
         draft_token_ids_tensor = torch.zeros(
             (self.num_speculative_tokens, *draft_token_ids.shape), dtype=draft_token_ids.dtype, device=self.device
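On the pcp/dcp prefill path the proposer skips the iterative draft loop entirely: it greedily takes one token per sampled position and repeats it for every speculative step. A sketch of the resulting shape:

```python
import torch

num_speculative_tokens = 3
logits = torch.randn(4, 1000)  # 4 sampled positions, toy vocab size

draft_token_ids = logits.argmax(dim=-1)            # shape [4]
draft_token_ids_list = [draft_token_ids] * num_speculative_tokens
stacked = torch.stack(draft_token_ids_list, dim=1)
print(stacked.shape)  # torch.Size([4, 3]); all columns identical
```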
@@ -722,6 +870,9 @@ class EagleProposer(VllmEagleProposer):
         input_batch_size,
         used_update_positions,
         aclgraph_runtime_mode,
+        ori_seq_len=None,
+        slot_indices=None,
+        mtp_slot_mapping=None,
     ):
         assert draft_step > 0
         common_attn_metadata = self.shallow_copy_metadata(old_common_metadata)
@@ -797,28 +948,42 @@ class EagleProposer(VllmEagleProposer):
             attn_metadata_builder = self._get_attention_metadata_builder()
         else:
             attn_metadata_builder = self.attn_metadata_builder
-        block_size = attn_metadata_builder.kv_cache_spec.block_size

-        # Compute the slot mapping.
-        if self.uses_mrope:
-            block_numbers = clamped_positions[0] // block_size
+        if self.pcp_size * self.dcp_size > 1:
+            num_computed_tokens_of_pcp_dcp = self.runner.pcp_manager._get_cp_local_seq_lens(
+                ori_seq_len + draft_step,
+                self.pcp_size,
+                self.dcp_size,
+                self.runner.parallel_config.cp_kv_cache_interleave_size,
+            )
+            cp_seq_len = num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank]
+            # update slot_mapping
+            slot_indices += self.pcp_size
+            slot_mapping = mtp_slot_mapping[slot_indices]
+            common_attn_metadata.slot_mapping[: batch_size * self.pcp_size] = slot_mapping
         else:
-            block_numbers = clamped_positions // block_size
-        block_ids = old_common_metadata.block_table_tensor.gather(dim=1, index=block_numbers.view(-1, 1))
-        block_ids = block_ids.view(-1)
-        if self.uses_mrope:
-            slot_mapping = block_ids * block_size + clamped_positions[0] % block_size
-        else:
-            slot_mapping = block_ids * block_size + clamped_positions % block_size
+            block_size = attn_metadata_builder.kv_cache_spec.block_size

-        # Mask out the slot mappings that exceed the max model length.
-        # Otherwise, the KV cache will be inadvertently updated with the
-        # padding tokens.
-        slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
-        self.slot_mapping_group[draft_step][: slot_mapping.shape[0]].copy_(slot_mapping.to(torch.int32))
-        self.slot_mapping_group[draft_step][slot_mapping.shape[0] :].fill_(PADDING_SLOT_ID)
-        # Set the address of the attn_metadata.slot_mapping to the self.slot_mapping_group[idx]
-        common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step][: slot_mapping.shape[0]]
+            # Compute the slot mapping.
+            if self.uses_mrope:
+                block_numbers = clamped_positions[0] // block_size
+            else:
+                block_numbers = clamped_positions // block_size
+            block_ids = old_common_metadata.block_table_tensor.gather(dim=1, index=block_numbers.view(-1, 1))
+            block_ids = block_ids.view(-1)
+            if self.uses_mrope:
+                slot_mapping = block_ids * block_size + clamped_positions[0] % block_size
+            else:
+                slot_mapping = block_ids * block_size + clamped_positions % block_size

+            # Mask out the slot mappings that exceed the max model length.
+            # Otherwise, the KV cache will be inadvertently updated with the
+            # padding tokens.
+            slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
+            self.slot_mapping_group[draft_step][: slot_mapping.shape[0]].copy_(slot_mapping.to(torch.int32))
+            self.slot_mapping_group[draft_step][slot_mapping.shape[0] :].fill_(PADDING_SLOT_ID)
+            # Set the address of the attn_metadata.slot_mapping to the self.slot_mapping_group[idx]
+            common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step][: slot_mapping.shape[0]]

         # Rebuild attention metadata
         attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
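The non-cp branch computes the slot mapping the standard paged-KV way: the draft position selects a logical block from the request's block-table row, and the offset inside that block is the position modulo the block size. A standalone sketch with an invented block table:

```python
import torch

PADDING_SLOT_ID = -1
block_size = 4
max_model_len = 16

# Toy block table: one row per request, mapping logical -> physical blocks.
block_table = torch.tensor([[7, 2, 9, 5],
                            [3, 8, 1, 6]])
clamped_positions = torch.tensor([5, 13])        # next draft token positions
exceeds_max_model_len = clamped_positions >= max_model_len

block_numbers = clamped_positions // block_size  # logical blocks [1, 3]
block_ids = block_table.gather(dim=1, index=block_numbers.view(-1, 1)).view(-1)
slot_mapping = block_ids * block_size + clamped_positions % block_size
# Positions past the model length must not touch the KV cache.
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
print(slot_mapping)  # tensor([ 9, 25])
```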
@@ -826,6 +991,12 @@ class EagleProposer(VllmEagleProposer):
             draft_index=draft_step,
         )

+        if self.pcp_size * self.dcp_size > 1:
+            if self.vllm_config.model_config.use_mla:
+                attn_metadata.decode.cp_seq_len = cp_seq_len
+            else:
+                attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp
+
         return common_attn_metadata, attn_metadata

     def prepare_next_token_ids_padded(
@@ -39,11 +39,7 @@ class MtpProposer(EagleProposer):
         # Currently, both GLM and DS encounter issues when enabling the fullgraph mode and running on EagleProposer.
         # Therefore, we temporarily bypass this problem by adding a conditional check for fullgraph.
         # TODO: this conditional check should be removed after bug fixing.
-        if (
-            self.pcp_size * self.dcp_size == 1
-            and not self.speculative_config.disable_padded_drafter_batch
-            and not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs()
-        ):
+        if not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs():
             super().dummy_run(
                 num_tokens,
                 with_prefill,
@@ -175,11 +171,7 @@ class MtpProposer(EagleProposer):
         # Currently, both GLM and DS encounter issues when enabling the fullgraph mode and running on EagleProposer.
         # Therefore, we temporarily bypass this problem by adding a conditional check for fullgraph.
         # TODO: this conditional check should be removed after bug fixing.
-        if (
-            self.pcp_size * self.dcp_size == 1
-            and not self.speculative_config.disable_padded_drafter_batch
-            and not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs()
-        ):
+        if not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs():
             draft_token_ids = super()._propose(
                 target_token_ids,
                 target_positions,
@@ -561,7 +561,6 @@ class NPUModelRunner(GPUModelRunner):
             dtype=np.int32,
         )
         attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens)
-        self.attn_state = attn_state  # type: ignore

         # Determine if it's a splitfuse batch
         with_prefill = attn_state not in [AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding]
@@ -800,7 +799,7 @@ class NPUModelRunner(GPUModelRunner):
             attn_state = AscendAttentionState.SpecDecoding
         # Speculative decoding.
         elif np.all(num_valid_tokens == 1):
-            if self.speculative_config and self.speculative_config.method == "mtp":
+            if self.speculative_config:
                 attn_state = AscendAttentionState.SpecDecoding
             else:
                 attn_state = AscendAttentionState.ChunkedPrefill
@@ -809,6 +808,14 @@ class NPUModelRunner(GPUModelRunner):
             attn_state = AscendAttentionState.ChunkedPrefill
         else:
             attn_state = AscendAttentionState.PrefillCacheHit
+
+        # When the PCP feature is combined with eagle3, attn_state needs to be recovered.
+        # TODO: resolve the conflict between the planned removal of attn_state and PCP, which requires this interface.
+        if attn_state == AscendAttentionState.SpecDecoding and self.speculative_config.method != "mtp":
+            self.attn_state = AscendAttentionState.ChunkedPrefill  # type: ignore
+        else:
+            self.attn_state = attn_state  # type: ignore
+
         return attn_state

     def _calc_spec_decode_metadata(
@@ -977,6 +984,10 @@ class NPUModelRunner(GPUModelRunner):
+                target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
+                target_positions = self._get_positions(num_scheduled_tokens)
+                target_hidden_states = hidden_states
+                if self.use_aux_hidden_state_outputs:
                     target_hidden_states = torch.cat(
                         [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1
                     )
             else:
                 token_indices_to_sample = None
                 # input_ids can be None for multimodal models.
@@ -1014,6 +1025,8 @@ class NPUModelRunner(GPUModelRunner):
+                target_token_ids = input_ids_pcp_full[token_indices]
+                target_positions = positions
                 target_hidden_states = hidden_states
                 if self.use_aux_hidden_state_outputs:
                     target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1)
             else:
                 target_token_ids = self.input_ids.gpu[token_indices]
                 target_positions = self._get_positions(token_indices)
@@ -1260,13 +1273,18 @@ class NPUModelRunner(GPUModelRunner):
                 num_tokens_padded, input_ids, positions, intermediate_tensors, inputs_embeds, **model_kwargs
             )
         with record_function_or_nullcontext("post process"):
-            aux_hidden_states = None
-            if self.use_aux_hidden_state_outputs:
-                hidden_states, aux_hidden_states = hidden_states
             if self.pcp_size > 1:
                 # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx
                 # ignores the padding from CUDA Graph.
                 hidden_states = self.pcp_manager.get_restore_hidden_states(hidden_states)
+            aux_hidden_states = None
+            if self.use_aux_hidden_state_outputs:
+                hidden_states, aux_hidden_states = hidden_states
+            if aux_hidden_states is not None:
+                aux_hidden_states = [
+                    self.pcp_manager.get_restore_hidden_states(aux_hidden_states_pcp)
+                    for aux_hidden_states_pcp in aux_hidden_states
+                ]

         if not self.broadcast_pp_output:
             # Common case.