[EAGLE] many fixes for eagle (#4195)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: Sehoon Kim <sehoon@x.ai>
This commit is contained in:
Lianmin Zheng
2025-03-07 22:12:13 -08:00
parent d052f4c8a9
commit d4017a6b63
15 changed files with 202 additions and 135 deletions

View File

@@ -928,45 +928,6 @@ class ModelRunner:
sampling_info.update_regex_vocab_mask()
sampling_info.apply_logits_bias(logits_output.next_token_logits)
def update_output_logprobs(
    self,
    logits_output: LogitsProcessorOutput,
    sampling_info: SamplingBatchInfo,
    top_logprobs_nums: List[int],
    token_ids_logprobs: List[int],
    next_token_ids: torch.Tensor,
    *,
    num_tokens_per_req: List[int],
):
    """Update ``logits_output``'s output logprobs in place for ``next_token_ids``.

    Args:
        logits_output: The logits output from the model forward.
        sampling_info: Sampling info for logprob calculation.
        top_logprobs_nums: Number of top logprobs requested, one entry per request.
        token_ids_logprobs: Token ids whose logprobs are requested, one entry
            per request.
        next_token_ids: Next token ids (already sampled) used to index logprobs.
        num_tokens_per_req: The number of tokens per request.

    Returns:
        None. ``logits_output`` is mutated by the sampler call.
    """
    self._preprocess_logits(logits_output, sampling_info)

    # The per-request settings must be repeated to match the per-token layout:
    # request i contributes num_tokens_per_req[i] rows, each of which needs
    # that request's top-k count and requested token ids.
    top_logprobs_nums_repeat_interleaved = []
    token_ids_logprobs_repeat_interleaved = []
    for num, token_ids, num_tokens in zip(
        top_logprobs_nums, token_ids_logprobs, num_tokens_per_req
    ):
        top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
        token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)

    # return_logprob=True: the sampler fills in the logprob fields of
    # logits_output for the provided batch_next_token_ids (no re-sampling).
    self.sampler(
        logits_output,
        sampling_info,
        True,
        top_logprobs_nums_repeat_interleaved,
        token_ids_logprobs_repeat_interleaved,
        batch_next_token_ids=next_token_ids,
    )
def sample(
self,
logits_output: LogitsProcessorOutput,