[EAGLE] many fixes for eagle (#4195)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: Sehoon Kim <sehoon@x.ai>
This commit is contained in:
@@ -7,6 +7,7 @@ import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
@@ -302,13 +303,10 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
# Set inputs
|
||||
forward_batch.input_ids = input_ids
|
||||
out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
|
||||
forward_batch.out_cache_loc = out_cache_loc[
|
||||
forward_batch.batch_size
|
||||
* self.topk
|
||||
* i : forward_batch.batch_size
|
||||
* self.topk
|
||||
* (i + 1)
|
||||
]
|
||||
:, self.topk * i : self.topk * (i + 1)
|
||||
].flatten()
|
||||
forward_batch.positions.add_(1)
|
||||
forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
|
||||
spec_info.hidden_states = hidden_states
|
||||
@@ -353,42 +351,70 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.spec_info = res.draft_input
|
||||
|
||||
if batch.return_logprob:
|
||||
# Compute output logprobs using the sampler.
|
||||
num_tokens_per_req = [
|
||||
accept + 1 for accept in res.accept_length_per_req_cpu
|
||||
]
|
||||
self.target_worker.model_runner.update_output_logprobs(
|
||||
logits_output,
|
||||
batch.sampling_info,
|
||||
batch.top_logprobs_nums,
|
||||
batch.token_ids_logprobs,
|
||||
res.verified_id,
|
||||
# +1 for bonus token.
|
||||
num_tokens_per_req=num_tokens_per_req,
|
||||
)
|
||||
|
||||
# Add output logprobs to the request.
|
||||
pt = 0
|
||||
# NOTE: tolist() of these values are skipped when output is processed
|
||||
next_token_logprobs = res.logits_output.next_token_logprobs.tolist()
|
||||
verified_ids = res.verified_id.tolist()
|
||||
for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
|
||||
for _ in range(num_tokens):
|
||||
if req.return_logprob:
|
||||
token_id = verified_ids[pt]
|
||||
req.output_token_logprobs_val.append(next_token_logprobs[pt])
|
||||
req.output_token_logprobs_idx.append(token_id)
|
||||
if req.top_logprobs_num > 0:
|
||||
req.output_top_logprobs_val.append(
|
||||
res.logits_output.next_token_top_logprobs_val[pt]
|
||||
)
|
||||
req.output_top_logprobs_idx.append(
|
||||
res.logits_output.next_token_top_logprobs_idx[pt]
|
||||
)
|
||||
pt += 1
|
||||
self.add_logprob_values(batch, res, logits_output)
|
||||
|
||||
return logits_output, res, model_worker_batch
|
||||
|
||||
def add_logprob_values(
|
||||
self,
|
||||
batch: ScheduleBatch,
|
||||
res: EagleVerifyOutput,
|
||||
logits_output: LogitsProcessorOutput,
|
||||
):
|
||||
# Extract args
|
||||
logits_output = res.logits_output
|
||||
top_logprobs_nums = batch.top_logprobs_nums
|
||||
token_ids_logprobs = batch.token_ids_logprobs
|
||||
logprobs = torch.nn.functional.log_softmax(
|
||||
logits_output.next_token_logits, dim=-1
|
||||
)
|
||||
batch_next_token_ids = res.verified_id
|
||||
num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
|
||||
|
||||
# We should repeat top_logprobs_nums to match num_tokens_per_req.
|
||||
top_logprobs_nums_repeat_interleaved = []
|
||||
token_ids_logprobs_repeat_interleaved = []
|
||||
for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
|
||||
top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
|
||||
for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
|
||||
token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
|
||||
|
||||
# Extract logprobs
|
||||
if any(x > 0 for x in top_logprobs_nums):
|
||||
(
|
||||
logits_output.next_token_top_logprobs_val,
|
||||
logits_output.next_token_top_logprobs_idx,
|
||||
) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved)
|
||||
|
||||
if any(x is not None for x in token_ids_logprobs):
|
||||
(
|
||||
logits_output.next_token_token_ids_logprobs_val,
|
||||
logits_output.next_token_token_ids_logprobs_idx,
|
||||
) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved)
|
||||
|
||||
logits_output.next_token_logprobs = logprobs[
|
||||
torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
|
||||
batch_next_token_ids,
|
||||
]
|
||||
|
||||
# Add output logprobs to the request.
|
||||
pt = 0
|
||||
next_token_logprobs = logits_output.next_token_logprobs.tolist()
|
||||
verified_ids = batch_next_token_ids.tolist()
|
||||
for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
|
||||
for _ in range(num_tokens):
|
||||
if req.return_logprob:
|
||||
req.output_token_logprobs_val.append(next_token_logprobs[pt])
|
||||
req.output_token_logprobs_idx.append(verified_ids[pt])
|
||||
if req.top_logprobs_num > 0:
|
||||
req.output_top_logprobs_val.append(
|
||||
res.logits_output.next_token_top_logprobs_val[pt]
|
||||
)
|
||||
req.output_top_logprobs_idx.append(
|
||||
res.logits_output.next_token_top_logprobs_idx[pt]
|
||||
)
|
||||
pt += 1
|
||||
|
||||
def forward_draft_extend(
|
||||
self,
|
||||
batch: ScheduleBatch,
|
||||
|
||||
Reference in New Issue
Block a user