[eagle3][pcp] fix bug for eagle3 and cp enable (#7309)

### What this PR does / why we need it?
This PR fixes a bug that occurs when EAGLE3 speculative decoding and context
parallelism (CP) are enabled together, introduced by the parallel speculative inference PR.

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
Verified with end-to-end tests and unit tests.

- vLLM version: v0.17.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: lilinsiman <lilinsiman@gmail.com>
This commit is contained in:
lilinsiman
2026-03-17 16:14:45 +08:00
committed by GitHub
parent 4e62a2ae15
commit 8f278fc101
2 changed files with 105 additions and 63 deletions

View File

@@ -29,6 +29,10 @@ prompts = [
"The president of United States is", "AI future is"
]
model = "wemaster/deepseek_mtp_main_random_bf16"
model_eagle3 = {
"main": "Qwen/Qwen3-8B",
"spec": "RedHatAI/Qwen3-8B-speculator.eagle3",
}
@wait_until_npu_memory_free()
def test_pcp_dcp_mtp1_eager():
@@ -141,3 +145,24 @@ def test_dcp_mtp3_full_graph():
async_scheduling=False,
) as runner:
runner.generate_greedy(prompts, 32)
@wait_until_npu_memory_free()
def test_pcp_eagle3_eager():
    """Smoke-test EAGLE3 speculative decoding with prefill context parallelism.

    Runs the main model with TP=2 and PCP=2 (DCP=1) in eager mode
    (enforce_eager=True, so no graph capture) and greedily generates 32 tokens
    for the shared ``prompts`` list. The test passes if generation completes
    without error; outputs are not compared against a baseline here.
    """
    with VllmRunner(
            model_eagle3["main"],
            max_model_len=1024,
            tensor_parallel_size=2,
            enforce_eager=True,
            # Exercise the PCP>1 path fixed by this change; decode CP stays off.
            prefill_context_parallel_size=2,
            decode_context_parallel_size=1,
            max_num_batched_tokens=1024,
            block_size=128,
            speculative_config={
                # 3 draft tokens per step from the EAGLE3 speculator model.
                "num_speculative_tokens": 3,
                "method": "eagle3",
                "model": model_eagle3["spec"]
            },
            async_scheduling=False,
    ) as runner:
        runner.generate_greedy(prompts, 32)

View File

@@ -475,7 +475,7 @@ class SpecDecodeBaseProposer(EagleProposer):
target_hidden_states = self.model.combine_hidden_states(target_hidden_states)
assert target_hidden_states.shape[-1] == self.hidden_size
num_tokens, token_indices_to_sample, common_attn_metadata = self.set_inputs_first_pass(
num_tokens, token_indices_to_sample, common_attn_metadata, long_seq_args = self.set_inputs_first_pass(
target_token_ids=target_token_ids,
next_token_ids=next_token_ids,
target_positions=target_positions,
@@ -483,65 +483,15 @@ class SpecDecodeBaseProposer(EagleProposer):
token_indices_to_sample=token_indices_to_sample,
cad=common_attn_metadata,
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
req_scheduled_tokens=req_scheduled_tokens,
long_seq_metadata=long_seq_metadata,
num_prefill_reqs=num_prefill_reqs,
num_decode_reqs=num_decode_reqs,
)
assert self.runner is not None
# update pcp related params
if self.pcp_size * self.dcp_size > 1:
assert long_seq_metadata is not None
common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
ori_token_indices_to_sample = token_indices_to_sample.clone()
query_lens_d = self.runner.query_lens[:num_decode_reqs]
if self.pcp_size > 1:
# 1. preprocess decode/prefill input_ids & target_hidden_states
# decode input_ids: keep unchanged
# decode target_hidden_states: remove padding
# prefill input_ids: add padding and pcp split
# prefill target_hidden_states: pcp split
num_tokens_d = query_lens_d.sum().item()
num_tokens_d_padded = num_tokens_d * self.pcp_size
input_ids_d = self.input_ids[:num_tokens_d]
input_ids_p = self.input_ids[num_tokens_d:num_tokens]
target_hidden_states_d_padded = target_hidden_states[:num_tokens_d_padded]
if num_tokens_d:
# remove padding (from pcp all-gather) in decode part
mask_start_loc = torch.cat(
[torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]]
)
mask_len = query_lens_d
mask = []
for req_id in range(num_decode_reqs):
mask += list(range(mask_start_loc[req_id], mask_start_loc[req_id] + mask_len[req_id]))
target_hidden_states_d = target_hidden_states_d_padded[mask]
else:
target_hidden_states_d = target_hidden_states_d_padded
target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
req_scheduled_tokens_p = {}
for i, req_id in enumerate(self.runner.input_batch.req_ids):
if i >= num_decode_reqs:
req_scheduled_tokens_p[req_id] = req_scheduled_tokens[req_id]
(num_tokens_p, input_ids_p, target_hidden_states_p, max_query_len_p, seq_lens_p, cu_num_tokens_p) = (
self._split_pcp_input(req_scheduled_tokens_p, input_ids_p, target_hidden_states_p)
)
num_tokens = num_tokens_d + num_tokens_p
target_positions = target_positions[:num_tokens]
self.input_ids[:num_tokens].copy_(torch.cat([input_ids_d, input_ids_p], dim=0))
target_hidden_states = torch.cat([target_hidden_states_d, target_hidden_states_p], dim=0)
# 2. update sample_indices according to main model
if num_decode_reqs:
token_indices_to_sample[:num_decode_reqs] = self.runner.logits_indices[
token_indices_to_sample[:num_decode_reqs]
]
if num_prefill_reqs:
token_indices_to_sample[-num_prefill_reqs:] = self.runner.logits_indices[-num_prefill_reqs:]
# 3. update attn_metadata params that may be influenced by pcp
common_attn_metadata.num_actual_tokens = num_tokens
common_attn_metadata.max_query_len = max(self.decode_threshold, max_query_len_p)
common_attn_metadata.seq_lens[-num_prefill_reqs:] = seq_lens_p
common_attn_metadata.seq_lens_cpu[-num_prefill_reqs:] = seq_lens_p
query_start_loc_p = cu_num_tokens_p[1:] + common_attn_metadata.query_start_loc[num_decode_reqs].item()
common_attn_metadata.query_start_loc[-num_prefill_reqs:] = query_start_loc_p
common_attn_metadata.query_start_loc_cpu[-num_prefill_reqs:] = query_start_loc_p
assert long_seq_args is not None
query_lens_d, ori_token_indices_to_sample = long_seq_args
assert self.runner is not None
if self.use_cuda_graph and num_tokens <= self.runner.cudagraph_batch_sizes[-1]:
num_input_tokens = self.runner.cudagraph_dispatcher._bs_to_padded_graph_size[num_tokens]
if not (
@@ -986,7 +936,11 @@ class SpecDecodeBaseProposer(EagleProposer):
token_indices_to_sample: torch.Tensor | None,
cad: CommonAttentionMetadata,
num_rejected_tokens_gpu: torch.Tensor | None,
) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
req_scheduled_tokens=None,
long_seq_metadata=None,
num_prefill_reqs=0,
num_decode_reqs=0,
) -> tuple[int, torch.Tensor, CommonAttentionMetadata, tuple[Any, Any] | None]:
if not self.needs_extra_input_slots:
# Default EAGLE pathway: no reshaping of input tensors needed.
# Simply rotate the input ids and leave the positions unchanged,
@@ -1002,6 +956,68 @@ class SpecDecodeBaseProposer(EagleProposer):
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[token_indices_to_sample] = next_token_ids
assert self.runner is not None
# update pcp related params
ori_token_indices_to_sample = None
query_lens_d = None
if self.pcp_size * self.dcp_size > 1:
assert long_seq_metadata is not None
cad.prefill_context_parallel_metadata = long_seq_metadata
ori_token_indices_to_sample = token_indices_to_sample.clone()
query_lens_d = self.runner.query_lens[:num_decode_reqs]
if self.pcp_size > 1:
# 1. preprocess decode/prefill input_ids & target_hidden_states
# decode input_ids: keep unchanged
# decode target_hidden_states: remove padding
# prefill input_ids: add padding and pcp split
# prefill target_hidden_states: pcp split
assert query_lens_d is not None
num_tokens_d = query_lens_d.sum().item()
num_tokens_d_padded = num_tokens_d * self.pcp_size
input_ids_d = self.input_ids[:num_tokens_d]
input_ids_p = self.input_ids[num_tokens_d:num_tokens]
target_hidden_states_d_padded = target_hidden_states[:num_tokens_d_padded]
if num_tokens_d:
# remove padding (from pcp all-gather) in decode part
mask_start_loc = torch.cat(
[torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]]
)
mask_len = query_lens_d
mask = []
for req_id in range(num_decode_reqs):
assert None not in (mask_start_loc, mask_len)
mask += list(range(mask_start_loc[req_id], mask_start_loc[req_id] + mask_len[req_id]))
target_hidden_states_d = target_hidden_states_d_padded[mask]
else:
target_hidden_states_d = target_hidden_states_d_padded
target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
req_scheduled_tokens_p = {}
for i, req_id in enumerate(self.runner.input_batch.req_ids):
if i >= num_decode_reqs:
req_scheduled_tokens_p[req_id] = req_scheduled_tokens[req_id]
(num_tokens_p, input_ids_p, target_hidden_states_p, max_query_len_p, seq_lens_p, cu_num_tokens_p) = (
self._split_pcp_input(req_scheduled_tokens_p, input_ids_p, target_hidden_states_p)
)
num_tokens = num_tokens_d + num_tokens_p
target_positions = target_positions[:num_tokens]
self.input_ids[:num_tokens].copy_(torch.cat([input_ids_d, input_ids_p], dim=0))
target_hidden_states = torch.cat([target_hidden_states_d, target_hidden_states_p], dim=0)
# 2. update sample_indices according to main model
if num_decode_reqs:
token_indices_to_sample[:num_decode_reqs] = self.runner.logits_indices[
token_indices_to_sample[:num_decode_reqs]
]
if num_prefill_reqs:
token_indices_to_sample[-num_prefill_reqs:] = self.runner.logits_indices[-num_prefill_reqs:]
# 3. update attn_metadata params that may be influenced by pcp
cad.num_actual_tokens = num_tokens
cad.max_query_len = max(self.decode_threshold, max_query_len_p)
cad.seq_lens[-num_prefill_reqs:] = seq_lens_p
cad.seq_lens_cpu[-num_prefill_reqs:] = seq_lens_p
query_start_loc_p = cu_num_tokens_p[1:] + cad.query_start_loc[num_decode_reqs].item()
cad.query_start_loc[-num_prefill_reqs:] = query_start_loc_p
cad.query_start_loc_cpu[-num_prefill_reqs:] = query_start_loc_p
# copy inputs to buffer for cudagraph
if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim == 0:
target_positions = target_positions[0]
@@ -1009,7 +1025,7 @@ class SpecDecodeBaseProposer(EagleProposer):
self._set_positions(num_tokens, target_positions)
self.hidden_states[:num_tokens] = target_hidden_states
return num_tokens, token_indices_to_sample, cad
return num_tokens, token_indices_to_sample, cad, (query_lens_d, ori_token_indices_to_sample)
else:
assert self.is_rejected_token_mask is not None
assert self.is_masked_token_mask is not None
@@ -1057,7 +1073,7 @@ class SpecDecodeBaseProposer(EagleProposer):
# Use torch.where to avoid DtoH sync from boolean indexing
mask = self.is_masked_token_mask[:total_num_output_tokens]
torch.where(
mask.unsqueeze(1),
mask.unsqueeze(1), # type: ignore
self.parallel_drafting_hidden_state_tensor,
self.hidden_states[:total_num_output_tokens],
out=self.hidden_states[:total_num_output_tokens],
@@ -1093,7 +1109,7 @@ class SpecDecodeBaseProposer(EagleProposer):
new_slot_mapping=new_slot_mapping,
)
return total_num_output_tokens, token_indices_to_sample, new_cad
return total_num_output_tokens, token_indices_to_sample, new_cad, None
def model_returns_tuple(self) -> bool:
return self.method not in ("mtp", "draft_model")
@@ -1198,7 +1214,8 @@ class SpecDecodeBaseProposer(EagleProposer):
# update slot_mapping
slot_indices += self.pcp_size
slot_mapping = mtp_slot_mapping[slot_indices]
common_attn_metadata.slot_mapping[: batch_size * self.pcp_size] = slot_mapping
self.slot_mapping_group[draft_step][: batch_size * self.pcp_size] = slot_mapping
common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step]
else:
# NOTE: In vllm, `block_size = attn_metadata_builder.kv_cache_spec.block_size`.
# However, in vllm-ascend, the above value can be multiple of `kernel_block_size`,