diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index e784bbb6..17abdc42 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -347,7 +347,7 @@ class EagleProposer(VllmEagleProposer):
         model_positions = self._get_positions(num_tokens)
-        batch_size = num_tokens // (self.num_speculative_tokens + 1)
+        batch_size = num_tokens // (self.num_speculative_tokens + 1) if not is_profile else self.runner.max_num_reqs
         with set_ascend_forward_context(
             multi_steps_attn_metadata[0] if multi_steps_attn_metadata else None,
@@ -613,7 +613,7 @@ class EagleProposer(VllmEagleProposer):
         hidden_states = hidden_states[last_token_indices]
         last_token_indices = self.arange[:batch_size]
-        input_batch_size = num_input_tokens
+        input_batch_size = num_input_tokens if (self.method == "mtp" or self.use_cuda_graph) else batch_size
         forward_context = get_forward_context()
         forward_context.num_tokens = input_batch_size