feat: add original logprobs to response (#8375)

Co-authored-by: Chayenne <zhaochen20@outlook.com>
Co-authored-by: luhongyu.4869 <luhongyu.4869@bytedance.com>
This commit is contained in:
narutolhy
2025-08-29 11:43:57 -07:00
committed by GitHub
parent f1e9bbaff5
commit 839c93bd2d
5 changed files with 246 additions and 12 deletions

View File

@@ -61,7 +61,7 @@ class LogitsProcessorOutput:
hidden_states: Optional[torch.Tensor] = None
## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
# The logprobs of the next tokens. shape: [#seq]
# he log probs of output tokens, if RETURN_ORIGINAL_LOGPROB = True, will get the log probs before applying temperature. If False, will get the log probs before applying temperature.
next_token_logprobs: Optional[torch.Tensor] = None
# The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
next_token_top_logprobs_val: Optional[List] = None

View File

@@ -27,6 +27,7 @@ if is_cuda():
logger = logging.getLogger(__name__)
SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
class Sampler(nn.Module):
@@ -77,7 +78,12 @@ class Sampler(nn.Module):
batch_next_token_ids = torch.argmax(logits, -1)
if return_logprob:
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
else:
# Post process original logits. if temperatures are all 1.0, no need to rescale
if return_logprob and RETURN_ORIGINAL_LOGPROB:
logprobs = torch.softmax(logits, dim=-1)
# Post process logits
logits.div_(sampling_info.temperatures)
logits[:] = torch.softmax(logits, dim=-1)
@@ -116,7 +122,12 @@ class Sampler(nn.Module):
if return_logprob:
# clamp to avoid -inf
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
if RETURN_ORIGINAL_LOGPROB:
logprobs = torch.log(logprobs).clamp(
min=torch.finfo(logprobs.dtype).min
)
else:
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
# Attach logprobs to logits_output (in-place modification)
if return_logprob:
@@ -201,7 +212,10 @@ def top_p_normalize_probs_torch(
return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
def get_top_logprobs(
logprobs: torch.Tensor,
top_logprobs_nums: List[int],
):
max_k = max(top_logprobs_nums)
ret = logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
@@ -212,10 +226,17 @@ def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
for i, k in enumerate(top_logprobs_nums):
output_top_logprobs_val.append(values[i][:k])
output_top_logprobs_idx.append(indices[i][:k])
return output_top_logprobs_val, output_top_logprobs_idx
return (
output_top_logprobs_val,
output_top_logprobs_idx,
)
def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
def get_token_ids_logprobs(
logprobs: torch.Tensor,
token_ids_logprobs: List[List[int]],
):
output_token_ids_logprobs_val = []
output_token_ids_logprobs_idx = []
for i, token_ids in enumerate(token_ids_logprobs):
@@ -226,7 +247,10 @@ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List
output_token_ids_logprobs_val.append([])
output_token_ids_logprobs_idx.append([])
return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
return (
output_token_ids_logprobs_val,
output_token_ids_logprobs_idx,
)
def apply_custom_logit_processor(

View File

@@ -46,6 +46,7 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
empty_context,
get_available_gpu_memory,
get_bool_env_var,
is_cuda,
next_power_of_2,
)
@@ -54,6 +55,7 @@ if is_cuda():
from sgl_kernel import segment_packbits
logger = logging.getLogger(__name__)
RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
@contextmanager
@@ -788,15 +790,20 @@ class EAGLEWorker(TpModelWorker):
token_ids_logprobs = batch.token_ids_logprobs
accepted_indices = res.accepted_indices
assert len(accepted_indices) == len(logits_output.next_token_logits)
temperatures = batch.sampling_info.temperatures
num_draft_tokens = batch.spec_info.draft_token_num
# acceptance indices are the indices in a "flattened" batch.
# dividing it to num_draft_tokens will yield the actual batch index.
temperatures = temperatures[accepted_indices // num_draft_tokens]
logprobs = torch.nn.functional.log_softmax(
logits_output.next_token_logits / temperatures, dim=-1
)
if RETURN_ORIGINAL_LOGPROB:
logprobs = torch.nn.functional.log_softmax(
logits_output.next_token_logits, dim=-1
)
else:
logprobs = torch.nn.functional.log_softmax(
logits_output.next_token_logits / temperatures, dim=-1
)
batch_next_token_ids = res.verified_id
num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
@@ -813,13 +820,19 @@ class EAGLEWorker(TpModelWorker):
(
logits_output.next_token_top_logprobs_val,
logits_output.next_token_top_logprobs_idx,
) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved)
) = get_top_logprobs(
logprobs,
top_logprobs_nums_repeat_interleaved,
)
if any(x is not None for x in token_ids_logprobs):
(
logits_output.next_token_token_ids_logprobs_val,
logits_output.next_token_token_ids_logprobs_idx,
) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved)
) = get_token_ids_logprobs(
logprobs,
token_ids_logprobs_repeat_interleaved,
)
logits_output.next_token_logprobs = logprobs[
torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),