diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 35291ec5..403d52a2 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -439,7 +439,6 @@ class SpecDecodeBaseProposer(EagleProposer):
             target_positions=model_positions,
             inputs_embeds=None,
             multi_steps_attn_metadata=multi_steps_attn_metadata,
-            is_dummy=True,
             num_tokens=num_tokens,
         )
         forward_context = get_forward_context()
@@ -702,7 +701,6 @@ class SpecDecodeBaseProposer(EagleProposer):
         inputs_embeds,
         multi_steps_attn_metadata,
         num_tokens,
-        is_dummy=False,
         is_prefill=None,
     ) -> torch.Tensor:
         # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all
@@ -755,7 +753,7 @@ class SpecDecodeBaseProposer(EagleProposer):
                 self.runner.pcp_manager.pcp_allgather_restore_idx.gpu[:
                     last_hidden_states.shape[0]],
             )
-        if lmhead_tp_enable() and not is_dummy:
+        if lmhead_tp_enable():
             max_num_reqs_across_dp = (
                 self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len
             )
@@ -766,7 +764,7 @@ class SpecDecodeBaseProposer(EagleProposer):
         sample_hidden_states = last_hidden_states[token_indices_to_sample]
         logits = self.model.compute_logits(sample_hidden_states)

-        if lmhead_tp_enable() and num_indices < logits.shape[0] and not is_dummy:
+        if lmhead_tp_enable() and num_indices < logits.shape[0]:
             logits = logits[:num_indices]
             token_indices_to_sample = token_indices_to_sample[:num_indices]

@@ -879,7 +877,7 @@ class SpecDecodeBaseProposer(EagleProposer):
         )
         num_indices = token_indices_to_sample.shape[0]

-        if lmhead_tp_enable() and not is_dummy:
+        if lmhead_tp_enable():
             max_num_reqs_across_dp = (
                 self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len
             )
@@ -891,7 +889,7 @@ class SpecDecodeBaseProposer(EagleProposer):
         sample_hidden_states = last_hidden_states[token_indices_to_sample]
         logits = self.model.compute_logits(sample_hidden_states)

-        if lmhead_tp_enable() and num_indices < logits.shape[0] and not is_dummy:
+        if lmhead_tp_enable() and num_indices < logits.shape[0]:
             logits = logits[:num_indices]
             token_indices_to_sample = token_indices_to_sample[:num_indices]