[eagle3][pcp] fix bug for eagle3 and cp enable (#7309)
### What this PR does / why we need it?
This PR fixes a bug in EAGLE-3 speculative decoding with context parallel (CP) enabled,
which was introduced by the parallel speculative inference PR.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
E2E tests and unit tests (UT), including the new `test_pcp_eagle3_eager` e2e test added in this PR.
- vLLM version: v0.17.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: lilinsiman <lilinsiman@gmail.com>
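
For context, a minimal sketch of the configuration this fix targets, mirroring the new `test_pcp_eagle3_eager` test in the diff below. The `prefill_context_parallel_size` / `decode_context_parallel_size` arguments are assumed to be forwarded to the engine the same way `VllmRunner` forwards them in the test; treat this as illustrative, not a canonical recipe.

```python
# Illustrative sketch only: mirrors the new test_pcp_eagle3_eager test below.
# Assumption: the *_context_parallel_size kwargs are accepted as engine
# arguments by vLLM-Ascend, as they are when passed through VllmRunner.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-8B",
    max_model_len=1024,
    tensor_parallel_size=2,
    enforce_eager=True,
    prefill_context_parallel_size=2,  # assumed pass-through engine arg (vLLM-Ascend)
    decode_context_parallel_size=1,   # assumed pass-through engine arg (vLLM-Ascend)
    speculative_config={
        "method": "eagle3",
        "model": "RedHatAI/Qwen3-8B-speculator.eagle3",
        "num_speculative_tokens": 3,
    },
)
outputs = llm.generate(
    ["The president of United States is"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```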
@@ -29,6 +29,10 @@ prompts = [
     "The president of United States is", "AI future is"
 ]
 model = "wemaster/deepseek_mtp_main_random_bf16"
+model_eagle3 = {
+    "main": "Qwen/Qwen3-8B",
+    "spec": "RedHatAI/Qwen3-8B-speculator.eagle3",
+}

 @wait_until_npu_memory_free()
 def test_pcp_dcp_mtp1_eager():
@@ -141,3 +145,24 @@ def test_dcp_mtp3_full_graph():
         async_scheduling=False,
     ) as runner:
         runner.generate_greedy(prompts, 32)
+
+
+@wait_until_npu_memory_free()
+def test_pcp_eagle3_eager():
+    with VllmRunner(
+        model_eagle3["main"],
+        max_model_len=1024,
+        tensor_parallel_size=2,
+        enforce_eager=True,
+        prefill_context_parallel_size=2,
+        decode_context_parallel_size=1,
+        max_num_batched_tokens=1024,
+        block_size=128,
+        speculative_config={
+            "num_speculative_tokens": 3,
+            "method": "eagle3",
+            "model": model_eagle3["spec"]
+        },
+        async_scheduling=False,
+    ) as runner:
+        runner.generate_greedy(prompts, 32)
@@ -475,7 +475,7 @@ class SpecDecodeBaseProposer(EagleProposer):
             target_hidden_states = self.model.combine_hidden_states(target_hidden_states)
         assert target_hidden_states.shape[-1] == self.hidden_size

-        num_tokens, token_indices_to_sample, common_attn_metadata = self.set_inputs_first_pass(
+        num_tokens, token_indices_to_sample, common_attn_metadata, long_seq_args = self.set_inputs_first_pass(
             target_token_ids=target_token_ids,
             next_token_ids=next_token_ids,
             target_positions=target_positions,
@@ -483,65 +483,15 @@ class SpecDecodeBaseProposer(EagleProposer):
             token_indices_to_sample=token_indices_to_sample,
             cad=common_attn_metadata,
             num_rejected_tokens_gpu=num_rejected_tokens_gpu,
+            req_scheduled_tokens=req_scheduled_tokens,
+            long_seq_metadata=long_seq_metadata,
+            num_prefill_reqs=num_prefill_reqs,
+            num_decode_reqs=num_decode_reqs,
         )

-        assert self.runner is not None
-        # update pcp related params
-        if self.pcp_size * self.dcp_size > 1:
-            assert long_seq_metadata is not None
-            common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
-            ori_token_indices_to_sample = token_indices_to_sample.clone()
-            query_lens_d = self.runner.query_lens[:num_decode_reqs]
-            if self.pcp_size > 1:
-                # 1. preprocess decode/prefill input_ids & target_hidden_states
-                # decode input_ids: keep unchanged
-                # decode target_hidden_states: remove padding
-                # prefill input_ids: add padding and pcp split
-                # prefill target_hidden_states: pcp split
-                num_tokens_d = query_lens_d.sum().item()
-                num_tokens_d_padded = num_tokens_d * self.pcp_size
-                input_ids_d = self.input_ids[:num_tokens_d]
-                input_ids_p = self.input_ids[num_tokens_d:num_tokens]
-                target_hidden_states_d_padded = target_hidden_states[:num_tokens_d_padded]
-                if num_tokens_d:
-                    # remove padding (from pcp all-gather) in decode part
-                    mask_start_loc = torch.cat(
-                        [torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]]
-                    )
-                    mask_len = query_lens_d
-                    mask = []
-                    for req_id in range(num_decode_reqs):
-                        mask += list(range(mask_start_loc[req_id], mask_start_loc[req_id] + mask_len[req_id]))
-                    target_hidden_states_d = target_hidden_states_d_padded[mask]
-                else:
-                    target_hidden_states_d = target_hidden_states_d_padded
-                target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
-                req_scheduled_tokens_p = {}
-                for i, req_id in enumerate(self.runner.input_batch.req_ids):
-                    if i >= num_decode_reqs:
-                        req_scheduled_tokens_p[req_id] = req_scheduled_tokens[req_id]
-                (num_tokens_p, input_ids_p, target_hidden_states_p, max_query_len_p, seq_lens_p, cu_num_tokens_p) = (
-                    self._split_pcp_input(req_scheduled_tokens_p, input_ids_p, target_hidden_states_p)
-                )
-                num_tokens = num_tokens_d + num_tokens_p
-                target_positions = target_positions[:num_tokens]
-                self.input_ids[:num_tokens].copy_(torch.cat([input_ids_d, input_ids_p], dim=0))
-                target_hidden_states = torch.cat([target_hidden_states_d, target_hidden_states_p], dim=0)
-                # 2. update sample_indices according to main model
-                if num_decode_reqs:
-                    token_indices_to_sample[:num_decode_reqs] = self.runner.logits_indices[
-                        token_indices_to_sample[:num_decode_reqs]
-                    ]
-                if num_prefill_reqs:
-                    token_indices_to_sample[-num_prefill_reqs:] = self.runner.logits_indices[-num_prefill_reqs:]
-                # 3. update attn_metadata params that may be influenced by pcp
-                common_attn_metadata.num_actual_tokens = num_tokens
-                common_attn_metadata.max_query_len = max(self.decode_threshold, max_query_len_p)
-                common_attn_metadata.seq_lens[-num_prefill_reqs:] = seq_lens_p
-                common_attn_metadata.seq_lens_cpu[-num_prefill_reqs:] = seq_lens_p
-                query_start_loc_p = cu_num_tokens_p[1:] + common_attn_metadata.query_start_loc[num_decode_reqs].item()
-                common_attn_metadata.query_start_loc[-num_prefill_reqs:] = query_start_loc_p
-                common_attn_metadata.query_start_loc_cpu[-num_prefill_reqs:] = query_start_loc_p
+        assert long_seq_args is not None
+        query_lens_d, ori_token_indices_to_sample = long_seq_args
+        assert self.runner is not None
         if self.use_cuda_graph and num_tokens <= self.runner.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.runner.cudagraph_dispatcher._bs_to_padded_graph_size[num_tokens]
             if not (
@@ -986,7 +936,11 @@ class SpecDecodeBaseProposer(EagleProposer):
         token_indices_to_sample: torch.Tensor | None,
         cad: CommonAttentionMetadata,
         num_rejected_tokens_gpu: torch.Tensor | None,
-    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
+        req_scheduled_tokens=None,
+        long_seq_metadata=None,
+        num_prefill_reqs=0,
+        num_decode_reqs=0,
+    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata, tuple[Any, Any] | None]:
         if not self.needs_extra_input_slots:
             # Default EAGLE pathway: no reshaping of input tensors needed.
             # Simply rotate the input ids and leave the positions unchanged,
@@ -1002,6 +956,68 @@ class SpecDecodeBaseProposer(EagleProposer):
             # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
             self.input_ids[token_indices_to_sample] = next_token_ids

+            assert self.runner is not None
+            # update pcp related params
+            ori_token_indices_to_sample = None
+            query_lens_d = None
+            if self.pcp_size * self.dcp_size > 1:
+                assert long_seq_metadata is not None
+                cad.prefill_context_parallel_metadata = long_seq_metadata
+                ori_token_indices_to_sample = token_indices_to_sample.clone()
+                query_lens_d = self.runner.query_lens[:num_decode_reqs]
+            if self.pcp_size > 1:
+                # 1. preprocess decode/prefill input_ids & target_hidden_states
+                # decode input_ids: keep unchanged
+                # decode target_hidden_states: remove padding
+                # prefill input_ids: add padding and pcp split
+                # prefill target_hidden_states: pcp split
+                assert query_lens_d is not None
+                num_tokens_d = query_lens_d.sum().item()
+                num_tokens_d_padded = num_tokens_d * self.pcp_size
+                input_ids_d = self.input_ids[:num_tokens_d]
+                input_ids_p = self.input_ids[num_tokens_d:num_tokens]
+                target_hidden_states_d_padded = target_hidden_states[:num_tokens_d_padded]
+                if num_tokens_d:
+                    # remove padding (from pcp all-gather) in decode part
+                    mask_start_loc = torch.cat(
+                        [torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]]
+                    )
+                    mask_len = query_lens_d
+                    mask = []
+                    for req_id in range(num_decode_reqs):
+                        assert None not in (mask_start_loc, mask_len)
+                        mask += list(range(mask_start_loc[req_id], mask_start_loc[req_id] + mask_len[req_id]))
+                    target_hidden_states_d = target_hidden_states_d_padded[mask]
+                else:
+                    target_hidden_states_d = target_hidden_states_d_padded
+                target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
+                req_scheduled_tokens_p = {}
+                for i, req_id in enumerate(self.runner.input_batch.req_ids):
+                    if i >= num_decode_reqs:
+                        req_scheduled_tokens_p[req_id] = req_scheduled_tokens[req_id]
+                (num_tokens_p, input_ids_p, target_hidden_states_p, max_query_len_p, seq_lens_p, cu_num_tokens_p) = (
+                    self._split_pcp_input(req_scheduled_tokens_p, input_ids_p, target_hidden_states_p)
+                )
+                num_tokens = num_tokens_d + num_tokens_p
+                target_positions = target_positions[:num_tokens]
+                self.input_ids[:num_tokens].copy_(torch.cat([input_ids_d, input_ids_p], dim=0))
+                target_hidden_states = torch.cat([target_hidden_states_d, target_hidden_states_p], dim=0)
+                # 2. update sample_indices according to main model
+                if num_decode_reqs:
+                    token_indices_to_sample[:num_decode_reqs] = self.runner.logits_indices[
+                        token_indices_to_sample[:num_decode_reqs]
+                    ]
+                if num_prefill_reqs:
+                    token_indices_to_sample[-num_prefill_reqs:] = self.runner.logits_indices[-num_prefill_reqs:]
+                # 3. update attn_metadata params that may be influenced by pcp
+                cad.num_actual_tokens = num_tokens
+                cad.max_query_len = max(self.decode_threshold, max_query_len_p)
+                cad.seq_lens[-num_prefill_reqs:] = seq_lens_p
+                cad.seq_lens_cpu[-num_prefill_reqs:] = seq_lens_p
+                query_start_loc_p = cu_num_tokens_p[1:] + cad.query_start_loc[num_decode_reqs].item()
+                cad.query_start_loc[-num_prefill_reqs:] = query_start_loc_p
+                cad.query_start_loc_cpu[-num_prefill_reqs:] = query_start_loc_p
+
             # copy inputs to buffer for cudagraph
             if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim == 0:
                 target_positions = target_positions[0]
@@ -1009,7 +1025,7 @@ class SpecDecodeBaseProposer(EagleProposer):
             self._set_positions(num_tokens, target_positions)
             self.hidden_states[:num_tokens] = target_hidden_states

-            return num_tokens, token_indices_to_sample, cad
+            return num_tokens, token_indices_to_sample, cad, (query_lens_d, ori_token_indices_to_sample)
         else:
             assert self.is_rejected_token_mask is not None
             assert self.is_masked_token_mask is not None
@@ -1057,7 +1073,7 @@ class SpecDecodeBaseProposer(EagleProposer):
             # Use torch.where to avoid DtoH sync from boolean indexing
             mask = self.is_masked_token_mask[:total_num_output_tokens]
             torch.where(
-                mask.unsqueeze(1),
+                mask.unsqueeze(1),  # type: ignore
                 self.parallel_drafting_hidden_state_tensor,
                 self.hidden_states[:total_num_output_tokens],
                 out=self.hidden_states[:total_num_output_tokens],
@@ -1093,7 +1109,7 @@ class SpecDecodeBaseProposer(EagleProposer):
                 new_slot_mapping=new_slot_mapping,
             )

-            return total_num_output_tokens, token_indices_to_sample, new_cad
+            return total_num_output_tokens, token_indices_to_sample, new_cad, None

     def model_returns_tuple(self) -> bool:
         return self.method not in ("mtp", "draft_model")
@@ -1198,7 +1214,8 @@ class SpecDecodeBaseProposer(EagleProposer):
                 # update slot_mapping
                 slot_indices += self.pcp_size
                 slot_mapping = mtp_slot_mapping[slot_indices]
-                common_attn_metadata.slot_mapping[: batch_size * self.pcp_size] = slot_mapping
+                self.slot_mapping_group[draft_step][: batch_size * self.pcp_size] = slot_mapping
+                common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step]
             else:
                 # NOTE: In vllm, `block_size = attn_metadata_builder.kv_cache_spec.block_size`.
                 # However, in vllm-ascend, the above value can be multiple of `kernel_block_size`,
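
As a reading aid for the decode-padding removal in `set_inputs_first_pass` (step 1 in the added block above), here is a small standalone sketch of the index mask it builds: after the PCP all-gather, each decode request occupies `query_len * pcp_size` slots in `target_hidden_states`, and only the first `query_len` of each group are real tokens. The toy sizes below are assumptions for illustration.

```python
import torch

# Standalone illustration of the mask built in set_inputs_first_pass above.
# Toy values (assumed): two decode requests, pcp_size = 2.
pcp_size = 2
query_lens_d = torch.tensor([2, 3], dtype=torch.int32)  # tokens per decode request
num_decode_reqs = query_lens_d.numel()

# Each request's slice in the gathered buffer is query_len * pcp_size wide.
mask_start_loc = torch.cat(
    [torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d * pcp_size, dim=0)[:-1]]
)
mask_len = query_lens_d

mask = []
for req_id in range(num_decode_reqs):
    # keep only the first query_len entries of each padded group
    mask += list(range(mask_start_loc[req_id], mask_start_loc[req_id] + mask_len[req_id]))

print(mask)  # [0, 1, 4, 5, 6] -> padding slots 2-3 and 7-9 are dropped
```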