[EAGLE] many fixes for eagle (#4195)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: Sehoon Kim <sehoon@x.ai>
This commit is contained in:
@@ -928,45 +928,6 @@ class ModelRunner:
|
||||
sampling_info.update_regex_vocab_mask()
|
||||
sampling_info.apply_logits_bias(logits_output.next_token_logits)
|
||||
|
||||
def update_output_logprobs(
    self,
    logits_output: LogitsProcessorOutput,
    sampling_info: SamplingBatchInfo,
    top_logprobs_nums: List[int],
    token_ids_logprobs: List[int],
    next_token_ids: torch.Tensor,
    *,
    num_tokens_per_req: List[int],
):
    """Update the output logprobs of ``logits_output`` for ``next_token_ids``.

    The logits are preprocessed first, then the per-request logprob
    settings are expanded to one entry per token and handed to the sampler
    together with the already-chosen token ids.

    Args:
        logits_output: The logits output from the model forward.
        sampling_info: Sampling info used for the logprob calculation.
        top_logprobs_nums: Number of top logprobs, one entry per request.
        token_ids_logprobs: Per-request entry forwarded to the sampler as
            its token-ids-logprobs argument; repeated the same way as
            ``top_logprobs_nums``.
        next_token_ids: Next token ids, passed to the sampler as
            ``batch_next_token_ids``.
        num_tokens_per_req: The number of tokens per request, i.e. how
            many times each per-request entry is repeated.

    Returns:
        None. The sampler's return value is discarded; this method is
        called for its effect on ``logits_output``.
    """
    self._preprocess_logits(logits_output, sampling_info)

    # Per-request settings must line up with the flattened token axis, so
    # each request's entry is repeated once per token of that request.
    # zip() stops at the shorter input, exactly as the explicit loops did.
    expanded_top_nums = [
        entry
        for top_num, repeat in zip(top_logprobs_nums, num_tokens_per_req)
        for entry in [top_num] * repeat
    ]
    expanded_token_ids = [
        entry
        for ids, repeat in zip(token_ids_logprobs, num_tokens_per_req)
        for entry in [ids] * repeat
    ]

    # NOTE(review): the positional ``True`` is presumably the sampler's
    # "return logprob" flag -- confirm against the sampler's signature.
    self.sampler(
        logits_output,
        sampling_info,
        True,
        expanded_top_nums,
        expanded_token_ids,
        batch_next_token_ids=next_token_ids,
    )
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits_output: LogitsProcessorOutput,
|
||||
|
||||
Reference in New Issue
Block a user