feat: add original logprobs to response (#8375)
Co-authored-by: Chayenne <zhaochen20@outlook.com> Co-authored-by: luhongyu.4869 <luhongyu.4869@bytedance.com>
This commit is contained in:
@@ -46,6 +46,7 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.utils import (
|
||||
empty_context,
|
||||
get_available_gpu_memory,
|
||||
get_bool_env_var,
|
||||
is_cuda,
|
||||
next_power_of_2,
|
||||
)
|
||||
@@ -54,6 +55,7 @@ if is_cuda():
|
||||
from sgl_kernel import segment_packbits
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
|
||||
|
||||
|
||||
@contextmanager
|
||||
@@ -788,15 +790,20 @@ class EAGLEWorker(TpModelWorker):
|
||||
token_ids_logprobs = batch.token_ids_logprobs
|
||||
accepted_indices = res.accepted_indices
|
||||
assert len(accepted_indices) == len(logits_output.next_token_logits)
|
||||
|
||||
temperatures = batch.sampling_info.temperatures
|
||||
num_draft_tokens = batch.spec_info.draft_token_num
|
||||
# acceptance indices are the indices in a "flattened" batch.
|
||||
# dividing it to num_draft_tokens will yield the actual batch index.
|
||||
temperatures = temperatures[accepted_indices // num_draft_tokens]
|
||||
|
||||
logprobs = torch.nn.functional.log_softmax(
|
||||
logits_output.next_token_logits / temperatures, dim=-1
|
||||
)
|
||||
if RETURN_ORIGINAL_LOGPROB:
|
||||
logprobs = torch.nn.functional.log_softmax(
|
||||
logits_output.next_token_logits, dim=-1
|
||||
)
|
||||
else:
|
||||
logprobs = torch.nn.functional.log_softmax(
|
||||
logits_output.next_token_logits / temperatures, dim=-1
|
||||
)
|
||||
batch_next_token_ids = res.verified_id
|
||||
num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
|
||||
|
||||
@@ -813,13 +820,19 @@ class EAGLEWorker(TpModelWorker):
|
||||
(
|
||||
logits_output.next_token_top_logprobs_val,
|
||||
logits_output.next_token_top_logprobs_idx,
|
||||
) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved)
|
||||
) = get_top_logprobs(
|
||||
logprobs,
|
||||
top_logprobs_nums_repeat_interleaved,
|
||||
)
|
||||
|
||||
if any(x is not None for x in token_ids_logprobs):
|
||||
(
|
||||
logits_output.next_token_token_ids_logprobs_val,
|
||||
logits_output.next_token_token_ids_logprobs_idx,
|
||||
) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved)
|
||||
) = get_token_ids_logprobs(
|
||||
logprobs,
|
||||
token_ids_logprobs_repeat_interleaved,
|
||||
)
|
||||
|
||||
logits_output.next_token_logprobs = logprobs[
|
||||
torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
|
||||
|
||||
Reference in New Issue
Block a user