feat: add original logprobs to response (#8375)

Co-authored-by: Chayenne <zhaochen20@outlook.com>
Co-authored-by: luhongyu.4869 <luhongyu.4869@bytedance.com>
This commit is contained in:
narutolhy
2025-08-29 11:43:57 -07:00
committed by GitHub
parent f1e9bbaff5
commit 839c93bd2d
5 changed files with 246 additions and 12 deletions

View File

@@ -61,7 +61,7 @@ class LogitsProcessorOutput:
hidden_states: Optional[torch.Tensor] = None
## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
# The logprobs of the next tokens. shape: [#seq]
# The log probs of output tokens. If RETURN_ORIGINAL_LOGPROB = True, these are the log probs before applying temperature; if False, the log probs after applying temperature.
next_token_logprobs: Optional[torch.Tensor] = None
# The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
next_token_top_logprobs_val: Optional[List] = None

View File

@@ -27,6 +27,7 @@ if is_cuda():
logger = logging.getLogger(__name__)
SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
class Sampler(nn.Module):
@@ -77,7 +78,12 @@ class Sampler(nn.Module):
batch_next_token_ids = torch.argmax(logits, -1)
if return_logprob:
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
else:
# Post process original logits. If temperatures are all 1.0, no rescaling is needed.
if return_logprob and RETURN_ORIGINAL_LOGPROB:
logprobs = torch.softmax(logits, dim=-1)
# Post process logits
logits.div_(sampling_info.temperatures)
logits[:] = torch.softmax(logits, dim=-1)
@@ -116,7 +122,12 @@ class Sampler(nn.Module):
if return_logprob:
# clamp to avoid -inf
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
if RETURN_ORIGINAL_LOGPROB:
logprobs = torch.log(logprobs).clamp(
min=torch.finfo(logprobs.dtype).min
)
else:
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
# Attach logprobs to logits_output (in-place modification)
if return_logprob:
@@ -201,7 +212,10 @@ def top_p_normalize_probs_torch(
return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
def get_top_logprobs(
logprobs: torch.Tensor,
top_logprobs_nums: List[int],
):
max_k = max(top_logprobs_nums)
ret = logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
@@ -212,10 +226,17 @@ def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
for i, k in enumerate(top_logprobs_nums):
output_top_logprobs_val.append(values[i][:k])
output_top_logprobs_idx.append(indices[i][:k])
return output_top_logprobs_val, output_top_logprobs_idx
return (
output_top_logprobs_val,
output_top_logprobs_idx,
)
def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
def get_token_ids_logprobs(
logprobs: torch.Tensor,
token_ids_logprobs: List[List[int]],
):
output_token_ids_logprobs_val = []
output_token_ids_logprobs_idx = []
for i, token_ids in enumerate(token_ids_logprobs):
@@ -226,7 +247,10 @@ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List
output_token_ids_logprobs_val.append([])
output_token_ids_logprobs_idx.append([])
return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
return (
output_token_ids_logprobs_val,
output_token_ids_logprobs_idx,
)
def apply_custom_logit_processor(

View File

@@ -46,6 +46,7 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
empty_context,
get_available_gpu_memory,
get_bool_env_var,
is_cuda,
next_power_of_2,
)
@@ -54,6 +55,7 @@ if is_cuda():
from sgl_kernel import segment_packbits
logger = logging.getLogger(__name__)
RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
@contextmanager
@@ -788,15 +790,20 @@ class EAGLEWorker(TpModelWorker):
token_ids_logprobs = batch.token_ids_logprobs
accepted_indices = res.accepted_indices
assert len(accepted_indices) == len(logits_output.next_token_logits)
temperatures = batch.sampling_info.temperatures
num_draft_tokens = batch.spec_info.draft_token_num
# Acceptance indices are indices into a "flattened" batch;
# dividing them by num_draft_tokens yields the actual batch index.
temperatures = temperatures[accepted_indices // num_draft_tokens]
logprobs = torch.nn.functional.log_softmax(
logits_output.next_token_logits / temperatures, dim=-1
)
if RETURN_ORIGINAL_LOGPROB:
logprobs = torch.nn.functional.log_softmax(
logits_output.next_token_logits, dim=-1
)
else:
logprobs = torch.nn.functional.log_softmax(
logits_output.next_token_logits / temperatures, dim=-1
)
batch_next_token_ids = res.verified_id
num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
@@ -813,13 +820,19 @@ class EAGLEWorker(TpModelWorker):
(
logits_output.next_token_top_logprobs_val,
logits_output.next_token_top_logprobs_idx,
) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved)
) = get_top_logprobs(
logprobs,
top_logprobs_nums_repeat_interleaved,
)
if any(x is not None for x in token_ids_logprobs):
(
logits_output.next_token_token_ids_logprobs_val,
logits_output.next_token_token_ids_logprobs_idx,
) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved)
) = get_token_ids_logprobs(
logprobs,
token_ids_logprobs_repeat_interleaved,
)
logits_output.next_token_logprobs = logprobs[
torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),

View File

@@ -87,6 +87,7 @@ suites = {
TestFile("test_mla_fp8.py", 93),
TestFile("test_no_chunked_prefill.py", 108),
TestFile("test_no_overlap_scheduler.py", 234),
TestFile("test_original_logprobs.py", 200),
TestFile("test_penalty.py", 41),
TestFile("test_page_size.py", 60),
TestFile("test_pytorch_sampling_backend.py", 66),

View File

@@ -0,0 +1,196 @@
"""Test original log probability alignment between SGLang and Hugging Face.
This test suite verifies the correctness of the `origin_logprobs` output (temperature=1)
and the `logprobs` output (temperature=0.5) in SGLang by comparing them against
raw logit-based probabilities computed directly from a reference Hugging Face model.
The test covers the following scenarios:
- Next-token prediction: Verifies that the log probability of the next token from
SGLang matches the Hugging Face model.
- Top-k logprobs: Ensures that the top-k original logprobs returned by SGLang are
consistent with Hugging Face outputs.
- Specified token IDs: Confirms that the original logprobs for specific token IDs
match the values computed from Hugging Face logits.
"""
import os
import random
import unittest
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
import sglang as sgl
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
# ------------------------- Configurable via env ------------------------- #
# Model under test and the fixed prompt set used for all comparisons.
MODEL_ID = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
PROMPTS = [
    "Hello, my name is",
    "The future of AI is",
    "The president of the United States is",
    "The capital of France is ",
]
# How many top-k logprobs and how many randomly chosen token ids to compare.
TOP_LOGPROBS_NUM = 50
NUM_RANDOM_TOKEN_IDS = 10
# Tolerances for torch.allclose comparisons (relative only; no absolute slack).
RTOL = 0.20
ATOL = 0.00
# ------------------------------------------------
# Determinism: seed torch (and CUDA, when present) and disable TF32 so the
# float32 HF reference is bit-stable across runs.
torch.manual_seed(1234)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(1234)
# NOTE(review): the diff rendering lost indentation; the TF32 flags are assumed
# to be top-level (they are safe to set without CUDA) — confirm against upstream.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
class TestOriginalLogprob(unittest.TestCase):
    """Compare SGLang logprob outputs against a float32 Hugging Face reference.

    Covers both settings of the RETURN_ORIGINAL_LOGPROB env var: when "True",
    SGLang's returned logprobs should match HF log-softmax of the raw logits
    (temperature 1.0); when "False", they should match HF log-softmax of the
    temperature-scaled logits (temperature 0.5).
    """

    def setUp(self):
        # ----- HF side (float32 weights, auto device placement) -----
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="right")
        self.hf_model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, torch_dtype=torch.float32, device_map="auto"
        )
        # Shared sampling parameters: SGLang samples at temperature 0.5, while
        # the "original" logprobs are expected to correspond to temperature 1.0.
        self.sampling_params = {
            "temperature": 0.5,
            "top_p": 1.0,
            "top_k": 10,
            "max_new_tokens": 1,
        }

    # ---------------------------------------------------------------------
    # Helper: compare one SGLang block (token_logprobs / top_logprobs /
    # ids_logprobs) against a reference HF logprob vector.
    # ---------------------------------------------------------------------
    def assert_logprobs_block_equal(
        self,
        hf_log_probs: torch.Tensor,  # [V] — full-vocab logprob vector from HF
        token_log_probs: list,
        top_log_probs: list,
        ids_log_probs: list,
        random_token_ids: list,
        tag: str = "",
    ):
        """Assert token-level, top-k, and requested-id logprobs all match HF."""
        # --- token-level logprobs ---
        vals, idxs, _ = zip(*token_log_probs)
        sgl_vals = torch.tensor(vals, device=self.hf_model.device, dtype=torch.float32)
        sgl_idxs = torch.tensor(idxs, device=self.hf_model.device, dtype=torch.long)
        hf_vals = hf_log_probs[sgl_idxs]
        self.assertTrue(
            torch.allclose(hf_vals, sgl_vals, rtol=RTOL, atol=ATOL),
            msg=f"[{tag}] tokenlevel mismatch at indices {sgl_idxs.tolist()}",
        )

        # --- top-k logprobs (compare the overlapping prefix of both lists) ---
        hf_top_vals, _ = torch.topk(hf_log_probs, k=TOP_LOGPROBS_NUM, dim=-1)
        sgl_top_vals = torch.tensor(
            [float(t[0]) for t in top_log_probs[0] if t and t[0] is not None][
                :TOP_LOGPROBS_NUM
            ],
            dtype=torch.float32,
            device=self.hf_model.device,
        )
        compare_k = min(hf_top_vals.numel(), sgl_top_vals.numel())
        self.assertTrue(
            torch.allclose(
                hf_top_vals[:compare_k], sgl_top_vals[:compare_k], rtol=RTOL, atol=ATOL
            ),
            msg=f"[{tag}] topk mismatch",
        )

        # --- logprobs for the explicitly requested token ids ---
        requested_ids = torch.tensor(
            random_token_ids, dtype=torch.long, device=hf_log_probs.device
        )
        hf_requested = hf_log_probs[requested_ids]
        sgl_requested = torch.tensor(
            [v for v, _, _ in ids_log_probs[0]],
            device=self.hf_model.device,
            dtype=torch.float32,
        )
        self.assertTrue(
            torch.allclose(hf_requested, sgl_requested, rtol=RTOL, atol=ATOL),
            msg=f"[{tag}] tokenIDs mismatch",
        )

        # Optional: print max abs diff for quick diagnostics
        max_diff = torch.max(torch.abs(hf_vals - sgl_vals)).item()
        print(f"[{tag}] max|diff| tokenlevel = {max_diff:.4f}")

    def test_logprob_match(self):
        """Exercise both RETURN_ORIGINAL_LOGPROB settings against the HF model."""
        vocab_size = self.tokenizer.vocab_size
        for env_val in ["True", "False"]:
            with self.subTest(return_original_logprob=env_val):
                # The engine reads this env var at startup, so set it before
                # constructing the engine for this sub-test.
                os.environ["RETURN_ORIGINAL_LOGPROB"] = env_val
                # ----- SGLang side -----
                sgl_engine = sgl.Engine(
                    model_path=MODEL_ID,
                    skip_tokenizer_init=True,
                    trust_remote_code=True,
                    mem_fraction_static=0.60,
                )
                for prompt in PROMPTS:
                    # NOTE(review): `random` is not seeded here, so the sampled
                    # ids differ across runs — torch seeding does not cover it.
                    random_token_ids = sorted(
                        random.sample(range(vocab_size), NUM_RANDOM_TOKEN_IDS)
                    )
                    enc = self.tokenizer(prompt, return_tensors="pt")
                    input_ids = enc["input_ids"].to(self.hf_model.device)
                    attn_mask = enc["attention_mask"].to(self.hf_model.device)
                    with torch.inference_mode():
                        hf_out = self.hf_model(
                            input_ids=input_ids,
                            attention_mask=attn_mask,
                            return_dict=True,
                        )
                        logits = hf_out.logits[:, -1, :]  # [1, V]
                        hf_log_probs = F.log_softmax(
                            logits.float() / self.sampling_params["temperature"],
                            dim=-1,
                        )[0]
                        hf_original_log_probs = F.log_softmax(
                            logits.float(), dim=-1
                        )[0]
                    outputs = sgl_engine.generate(
                        input_ids=input_ids[0].tolist(),
                        sampling_params=self.sampling_params,
                        return_logprob=True,
                        top_logprobs_num=TOP_LOGPROBS_NUM,
                        token_ids_logprob=random_token_ids,
                    )
                    if isinstance(outputs, list):
                        outputs = outputs[0]
                    meta = outputs["meta_info"]
                    if env_val.lower() == "true":
                        # Original (pre-temperature) logprobs are enabled.
                        self.assert_logprobs_block_equal(
                            hf_log_probs=hf_original_log_probs,
                            token_log_probs=meta["output_token_logprobs"],
                            top_log_probs=meta["output_top_logprobs"],
                            ids_log_probs=meta["output_token_ids_logprobs"],
                            random_token_ids=random_token_ids,
                            tag=f"Original logprobs SGLang vs HF: {prompt} ({env_val})",
                        )
                    else:
                        # Regular (temperature-scaled) logprobs.
                        self.assert_logprobs_block_equal(
                            hf_log_probs=hf_log_probs,
                            token_log_probs=meta["output_token_logprobs"],
                            top_log_probs=meta["output_top_logprobs"],
                            ids_log_probs=meta["output_token_ids_logprobs"],
                            random_token_ids=random_token_ids,
                            tag=f"logprobs SGLang vs HF: {prompt} ({env_val})",
                        )
                sgl_engine.shutdown()
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()