feat: add original logprobs to response (#8375)
Co-authored-by: Chayenne <zhaochen20@outlook.com> Co-authored-by: luhongyu.4869 <luhongyu.4869@bytedance.com>
This commit is contained in:
@@ -61,7 +61,7 @@ class LogitsProcessorOutput:
|
||||
hidden_states: Optional[torch.Tensor] = None
|
||||
|
||||
## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
|
||||
# The logprobs of the next tokens. shape: [#seq]
|
||||
# he log probs of output tokens, if RETURN_ORIGINAL_LOGPROB = True, will get the log probs before applying temperature. If False, will get the log probs before applying temperature.
|
||||
next_token_logprobs: Optional[torch.Tensor] = None
|
||||
# The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
|
||||
next_token_top_logprobs_val: Optional[List] = None
|
||||
|
||||
@@ -27,6 +27,7 @@ if is_cuda():
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
|
||||
RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
|
||||
|
||||
|
||||
class Sampler(nn.Module):
|
||||
@@ -77,7 +78,12 @@ class Sampler(nn.Module):
|
||||
batch_next_token_ids = torch.argmax(logits, -1)
|
||||
if return_logprob:
|
||||
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
|
||||
else:
|
||||
# Post process original logits. if temperatures are all 1.0, no need to rescale
|
||||
if return_logprob and RETURN_ORIGINAL_LOGPROB:
|
||||
logprobs = torch.softmax(logits, dim=-1)
|
||||
|
||||
# Post process logits
|
||||
logits.div_(sampling_info.temperatures)
|
||||
logits[:] = torch.softmax(logits, dim=-1)
|
||||
@@ -116,7 +122,12 @@ class Sampler(nn.Module):
|
||||
|
||||
if return_logprob:
|
||||
# clamp to avoid -inf
|
||||
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
|
||||
if RETURN_ORIGINAL_LOGPROB:
|
||||
logprobs = torch.log(logprobs).clamp(
|
||||
min=torch.finfo(logprobs.dtype).min
|
||||
)
|
||||
else:
|
||||
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
|
||||
|
||||
# Attach logprobs to logits_output (in-place modification)
|
||||
if return_logprob:
|
||||
@@ -201,7 +212,10 @@ def top_p_normalize_probs_torch(
|
||||
return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
|
||||
|
||||
|
||||
def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
|
||||
def get_top_logprobs(
|
||||
logprobs: torch.Tensor,
|
||||
top_logprobs_nums: List[int],
|
||||
):
|
||||
max_k = max(top_logprobs_nums)
|
||||
ret = logprobs.topk(max_k, dim=1)
|
||||
values = ret.values.tolist()
|
||||
@@ -212,10 +226,17 @@ def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
|
||||
for i, k in enumerate(top_logprobs_nums):
|
||||
output_top_logprobs_val.append(values[i][:k])
|
||||
output_top_logprobs_idx.append(indices[i][:k])
|
||||
return output_top_logprobs_val, output_top_logprobs_idx
|
||||
|
||||
return (
|
||||
output_top_logprobs_val,
|
||||
output_top_logprobs_idx,
|
||||
)
|
||||
|
||||
|
||||
def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
|
||||
def get_token_ids_logprobs(
|
||||
logprobs: torch.Tensor,
|
||||
token_ids_logprobs: List[List[int]],
|
||||
):
|
||||
output_token_ids_logprobs_val = []
|
||||
output_token_ids_logprobs_idx = []
|
||||
for i, token_ids in enumerate(token_ids_logprobs):
|
||||
@@ -226,7 +247,10 @@ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List
|
||||
output_token_ids_logprobs_val.append([])
|
||||
output_token_ids_logprobs_idx.append([])
|
||||
|
||||
return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
|
||||
return (
|
||||
output_token_ids_logprobs_val,
|
||||
output_token_ids_logprobs_idx,
|
||||
)
|
||||
|
||||
|
||||
def apply_custom_logit_processor(
|
||||
|
||||
@@ -46,6 +46,7 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.utils import (
|
||||
empty_context,
|
||||
get_available_gpu_memory,
|
||||
get_bool_env_var,
|
||||
is_cuda,
|
||||
next_power_of_2,
|
||||
)
|
||||
@@ -54,6 +55,7 @@ if is_cuda():
|
||||
from sgl_kernel import segment_packbits
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
|
||||
|
||||
|
||||
@contextmanager
|
||||
@@ -788,15 +790,20 @@ class EAGLEWorker(TpModelWorker):
|
||||
token_ids_logprobs = batch.token_ids_logprobs
|
||||
accepted_indices = res.accepted_indices
|
||||
assert len(accepted_indices) == len(logits_output.next_token_logits)
|
||||
|
||||
temperatures = batch.sampling_info.temperatures
|
||||
num_draft_tokens = batch.spec_info.draft_token_num
|
||||
# acceptance indices are the indices in a "flattened" batch.
|
||||
# dividing it to num_draft_tokens will yield the actual batch index.
|
||||
temperatures = temperatures[accepted_indices // num_draft_tokens]
|
||||
|
||||
logprobs = torch.nn.functional.log_softmax(
|
||||
logits_output.next_token_logits / temperatures, dim=-1
|
||||
)
|
||||
if RETURN_ORIGINAL_LOGPROB:
|
||||
logprobs = torch.nn.functional.log_softmax(
|
||||
logits_output.next_token_logits, dim=-1
|
||||
)
|
||||
else:
|
||||
logprobs = torch.nn.functional.log_softmax(
|
||||
logits_output.next_token_logits / temperatures, dim=-1
|
||||
)
|
||||
batch_next_token_ids = res.verified_id
|
||||
num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
|
||||
|
||||
@@ -813,13 +820,19 @@ class EAGLEWorker(TpModelWorker):
|
||||
(
|
||||
logits_output.next_token_top_logprobs_val,
|
||||
logits_output.next_token_top_logprobs_idx,
|
||||
) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved)
|
||||
) = get_top_logprobs(
|
||||
logprobs,
|
||||
top_logprobs_nums_repeat_interleaved,
|
||||
)
|
||||
|
||||
if any(x is not None for x in token_ids_logprobs):
|
||||
(
|
||||
logits_output.next_token_token_ids_logprobs_val,
|
||||
logits_output.next_token_token_ids_logprobs_idx,
|
||||
) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved)
|
||||
) = get_token_ids_logprobs(
|
||||
logprobs,
|
||||
token_ids_logprobs_repeat_interleaved,
|
||||
)
|
||||
|
||||
logits_output.next_token_logprobs = logprobs[
|
||||
torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
|
||||
|
||||
@@ -87,6 +87,7 @@ suites = {
|
||||
TestFile("test_mla_fp8.py", 93),
|
||||
TestFile("test_no_chunked_prefill.py", 108),
|
||||
TestFile("test_no_overlap_scheduler.py", 234),
|
||||
TestFile("test_original_logprobs.py", 200),
|
||||
TestFile("test_penalty.py", 41),
|
||||
TestFile("test_page_size.py", 60),
|
||||
TestFile("test_pytorch_sampling_backend.py", 66),
|
||||
|
||||
196
test/srt/test_original_logprobs.py
Normal file
196
test/srt/test_original_logprobs.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""Test original log probability alignment between SGLang and Hugging Face.
|
||||
|
||||
This test suite verifies the correctness of the `origin_logprobs` output (temperature=1)
|
||||
and the `logprobs` output (temperature=0.5) in SGLang by comparing it against
|
||||
raw logit-based probabilities computed directly from a reference Hugging Face model.
|
||||
|
||||
The test covers the following scenarios:
|
||||
- Next-token prediction: Verifies that the log probability of the next token from
|
||||
SGLang matches the Hugging Face model.
|
||||
- Top-k logprobs: Ensures that the top-k original logprobs returned by SGLang are
|
||||
consistent with Hugging Face outputs.
|
||||
- Specified token IDs: Confirms that the original logprobs for specific token IDs
|
||||
match the values computed from Hugging Face logits.
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
|
||||
# ------------------------- Configurable via env ------------------------- #
|
||||
MODEL_ID = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
PROMPTS = [
|
||||
"Hello, my name is",
|
||||
"The future of AI is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is ",
|
||||
]
|
||||
TOP_LOGPROBS_NUM = 50
|
||||
NUM_RANDOM_TOKEN_IDS = 10
|
||||
RTOL = 0.20
|
||||
ATOL = 0.00
|
||||
# ------------------------------------------------
|
||||
|
||||
torch.manual_seed(1234)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(1234)
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
|
||||
|
||||
class TestOriginalLogprob(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# ----- HF side (float32 weights) -----
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="right")
|
||||
self.hf_model = AutoModelForCausalLM.from_pretrained(
|
||||
MODEL_ID, torch_dtype=torch.float32, device_map="auto"
|
||||
)
|
||||
|
||||
# Shared sampling parameters
|
||||
self.sampling_params = {
|
||||
"temperature": 0.5, # SGLang uses 0.5, but original logprobs are used 1.0
|
||||
"top_p": 1.0,
|
||||
"top_k": 10,
|
||||
"max_new_tokens": 1,
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Helper: compare one SGLang block (token_logprobs / top_logprobs / ids_logprobs)
|
||||
# against a reference HF log‑prob vector.
|
||||
# ---------------------------------------------------------------------
|
||||
def assert_logprobs_block_equal(
|
||||
self,
|
||||
hf_log_probs: torch.Tensor, # [V]
|
||||
token_log_probs: list,
|
||||
top_log_probs: list,
|
||||
ids_log_probs: list,
|
||||
random_token_ids: list,
|
||||
tag: str = "",
|
||||
):
|
||||
vals, idxs, _ = zip(*token_log_probs)
|
||||
sgl_vals = torch.tensor(vals, device=self.hf_model.device, dtype=torch.float32)
|
||||
sgl_idxs = torch.tensor(idxs, device=self.hf_model.device, dtype=torch.long)
|
||||
hf_vals = hf_log_probs[sgl_idxs]
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(hf_vals, sgl_vals, rtol=RTOL, atol=ATOL),
|
||||
msg=f"[{tag}] token‑level mismatch at indices {sgl_idxs.tolist()}",
|
||||
)
|
||||
|
||||
hf_topk, _ = torch.topk(hf_log_probs, k=TOP_LOGPROBS_NUM, dim=-1)
|
||||
|
||||
sgl_topk = torch.tensor(
|
||||
[float(t[0]) for t in top_log_probs[0] if t and t[0] is not None][
|
||||
:TOP_LOGPROBS_NUM
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=self.hf_model.device,
|
||||
)
|
||||
|
||||
k = min(hf_topk.numel(), sgl_topk.numel())
|
||||
self.assertTrue(
|
||||
torch.allclose(hf_topk[:k], sgl_topk[:k], rtol=RTOL, atol=ATOL),
|
||||
msg=f"[{tag}] top‑k mismatch",
|
||||
)
|
||||
|
||||
indices = torch.tensor(
|
||||
random_token_ids, dtype=torch.long, device=hf_log_probs.device
|
||||
)
|
||||
|
||||
hf_token_ids = hf_log_probs[indices]
|
||||
|
||||
sgl_token_ids = torch.tensor(
|
||||
[v for v, _, _ in ids_log_probs[0]],
|
||||
device=self.hf_model.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(hf_token_ids, sgl_token_ids, rtol=RTOL, atol=ATOL),
|
||||
msg=f"[{tag}] token‑IDs mismatch",
|
||||
)
|
||||
|
||||
# Optional: print max abs diff for quick diagnostics
|
||||
max_diff = torch.max(torch.abs(hf_vals - sgl_vals)).item()
|
||||
print(f"[{tag}] max|diff| token‑level = {max_diff:.4f}")
|
||||
|
||||
def test_logprob_match(self):
|
||||
vocab_size = self.tokenizer.vocab_size
|
||||
|
||||
for env_val in ["True", "False"]:
|
||||
with self.subTest(return_original_logprob=env_val):
|
||||
os.environ["RETURN_ORIGINAL_LOGPROB"] = env_val
|
||||
|
||||
# ----- SGLang side -----
|
||||
sgl_engine = sgl.Engine(
|
||||
model_path=MODEL_ID,
|
||||
skip_tokenizer_init=True,
|
||||
trust_remote_code=True,
|
||||
mem_fraction_static=0.60,
|
||||
)
|
||||
|
||||
for prompt in PROMPTS:
|
||||
random_token_ids = sorted(
|
||||
random.sample(range(vocab_size), NUM_RANDOM_TOKEN_IDS)
|
||||
)
|
||||
|
||||
enc = self.tokenizer(prompt, return_tensors="pt")
|
||||
input_ids = enc["input_ids"].to(self.hf_model.device)
|
||||
attn_mask = enc["attention_mask"].to(self.hf_model.device)
|
||||
|
||||
with torch.inference_mode():
|
||||
hf_out = self.hf_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attn_mask,
|
||||
return_dict=True,
|
||||
)
|
||||
logits = hf_out.logits[:, -1, :] # [1, V]
|
||||
hf_log_probs = F.log_softmax(
|
||||
logits.float() / self.sampling_params["temperature"], dim=-1
|
||||
)[0]
|
||||
hf_original_log_probs = F.log_softmax(logits.float(), dim=-1)[0]
|
||||
|
||||
outputs = sgl_engine.generate(
|
||||
input_ids=input_ids[0].tolist(),
|
||||
sampling_params=self.sampling_params,
|
||||
return_logprob=True,
|
||||
top_logprobs_num=TOP_LOGPROBS_NUM,
|
||||
token_ids_logprob=random_token_ids,
|
||||
)
|
||||
|
||||
if isinstance(outputs, list):
|
||||
outputs = outputs[0]
|
||||
meta = outputs["meta_info"]
|
||||
|
||||
# Check original logprobs only if enabled
|
||||
if env_val.lower() == "true":
|
||||
self.assert_logprobs_block_equal(
|
||||
hf_log_probs=hf_original_log_probs,
|
||||
token_log_probs=meta["output_token_logprobs"],
|
||||
top_log_probs=meta["output_top_logprobs"],
|
||||
ids_log_probs=meta["output_token_ids_logprobs"],
|
||||
random_token_ids=random_token_ids,
|
||||
tag=f"Original logprobs SGLang vs HF: {prompt} ({env_val})",
|
||||
)
|
||||
else:
|
||||
# Always check regular logprobs
|
||||
self.assert_logprobs_block_equal(
|
||||
hf_log_probs=hf_log_probs,
|
||||
token_log_probs=meta["output_token_logprobs"],
|
||||
top_log_probs=meta["output_top_logprobs"],
|
||||
ids_log_probs=meta["output_token_ids_logprobs"],
|
||||
random_token_ids=random_token_ids,
|
||||
tag=f"logprobs SGLang vs HF: {prompt} ({env_val})",
|
||||
)
|
||||
sgl_engine.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user