From 839c93bd2d141f0064ded6d828eb7a479629edc9 Mon Sep 17 00:00:00 2001 From: narutolhy <582909902@qq.com> Date: Fri, 29 Aug 2025 11:43:57 -0700 Subject: [PATCH] feat: add original logprobs to response (#8375) Co-authored-by: Chayenne Co-authored-by: luhongyu.4869 --- python/sglang/srt/layers/logits_processor.py | 2 +- python/sglang/srt/layers/sampler.py | 34 ++- python/sglang/srt/speculative/eagle_worker.py | 25 ++- test/srt/run_suite.py | 1 + test/srt/test_original_logprobs.py | 196 ++++++++++++++++++ 5 files changed, 246 insertions(+), 12 deletions(-) create mode 100644 test/srt/test_original_logprobs.py diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 00b30a848..a4fb29929 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -61,7 +61,7 @@ class LogitsProcessorOutput: hidden_states: Optional[torch.Tensor] = None ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler - # The logprobs of the next tokens. shape: [#seq] + # The log probs of output tokens. If RETURN_ORIGINAL_LOGPROB = True, will get the log probs before applying temperature. If False, will get the log probs after applying temperature. next_token_logprobs: Optional[torch.Tensor] = None # The logprobs and ids of the top-k tokens in output positions. 
shape: [#seq, k] next_token_top_logprobs_val: Optional[List] = None diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index cf4222cc7..56a831f2d 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -27,6 +27,7 @@ if is_cuda(): logger = logging.getLogger(__name__) SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP") +RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB") class Sampler(nn.Module): @@ -77,7 +78,12 @@ class Sampler(nn.Module): batch_next_token_ids = torch.argmax(logits, -1) if return_logprob: logprobs = torch.nn.functional.log_softmax(logits, dim=-1) + else: + # Post process original logits. if temperatures are all 1.0, no need to rescale + if return_logprob and RETURN_ORIGINAL_LOGPROB: + logprobs = torch.softmax(logits, dim=-1) + # Post process logits logits.div_(sampling_info.temperatures) logits[:] = torch.softmax(logits, dim=-1) @@ -116,7 +122,12 @@ class Sampler(nn.Module): if return_logprob: # clamp to avoid -inf - logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min) + if RETURN_ORIGINAL_LOGPROB: + logprobs = torch.log(logprobs).clamp( + min=torch.finfo(logprobs.dtype).min + ) + else: + logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min) # Attach logprobs to logits_output (in-place modification) if return_logprob: @@ -201,7 +212,10 @@ def top_p_normalize_probs_torch( return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort) -def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]): +def get_top_logprobs( + logprobs: torch.Tensor, + top_logprobs_nums: List[int], +): max_k = max(top_logprobs_nums) ret = logprobs.topk(max_k, dim=1) values = ret.values.tolist() @@ -212,10 +226,17 @@ def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]): for i, k in enumerate(top_logprobs_nums): output_top_logprobs_val.append(values[i][:k]) 
output_top_logprobs_idx.append(indices[i][:k]) - return output_top_logprobs_val, output_top_logprobs_idx + + return ( + output_top_logprobs_val, + output_top_logprobs_idx, + ) -def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]): +def get_token_ids_logprobs( + logprobs: torch.Tensor, + token_ids_logprobs: List[List[int]], +): output_token_ids_logprobs_val = [] output_token_ids_logprobs_idx = [] for i, token_ids in enumerate(token_ids_logprobs): @@ -226,7 +247,10 @@ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List output_token_ids_logprobs_val.append([]) output_token_ids_logprobs_idx.append([]) - return output_token_ids_logprobs_val, output_token_ids_logprobs_idx + return ( + output_token_ids_logprobs_val, + output_token_ids_logprobs_idx, + ) def apply_custom_logit_processor( diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 5a9454cd2..24e3eca95 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -46,6 +46,7 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import ( empty_context, get_available_gpu_memory, + get_bool_env_var, is_cuda, next_power_of_2, ) @@ -54,6 +55,7 @@ if is_cuda(): from sgl_kernel import segment_packbits logger = logging.getLogger(__name__) +RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB") @contextmanager @@ -788,15 +790,20 @@ class EAGLEWorker(TpModelWorker): token_ids_logprobs = batch.token_ids_logprobs accepted_indices = res.accepted_indices assert len(accepted_indices) == len(logits_output.next_token_logits) + temperatures = batch.sampling_info.temperatures num_draft_tokens = batch.spec_info.draft_token_num # acceptance indices are the indices in a "flattened" batch. # dividing it to num_draft_tokens will yield the actual batch index. 
temperatures = temperatures[accepted_indices // num_draft_tokens] - - logprobs = torch.nn.functional.log_softmax( - logits_output.next_token_logits / temperatures, dim=-1 - ) + if RETURN_ORIGINAL_LOGPROB: + logprobs = torch.nn.functional.log_softmax( + logits_output.next_token_logits, dim=-1 + ) + else: + logprobs = torch.nn.functional.log_softmax( + logits_output.next_token_logits / temperatures, dim=-1 + ) batch_next_token_ids = res.verified_id num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu] @@ -813,13 +820,19 @@ class EAGLEWorker(TpModelWorker): ( logits_output.next_token_top_logprobs_val, logits_output.next_token_top_logprobs_idx, - ) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved) + ) = get_top_logprobs( + logprobs, + top_logprobs_nums_repeat_interleaved, + ) if any(x is not None for x in token_ids_logprobs): ( logits_output.next_token_token_ids_logprobs_val, logits_output.next_token_token_ids_logprobs_idx, - ) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved) + ) = get_token_ids_logprobs( + logprobs, + token_ids_logprobs_repeat_interleaved, + ) logits_output.next_token_logprobs = logprobs[ torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device), diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 2b1ef4c53..cd219f082 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -87,6 +87,7 @@ suites = { TestFile("test_mla_fp8.py", 93), TestFile("test_no_chunked_prefill.py", 108), TestFile("test_no_overlap_scheduler.py", 234), + TestFile("test_original_logprobs.py", 200), TestFile("test_penalty.py", 41), TestFile("test_page_size.py", 60), TestFile("test_pytorch_sampling_backend.py", 66), diff --git a/test/srt/test_original_logprobs.py b/test/srt/test_original_logprobs.py new file mode 100644 index 000000000..ddcfe3d8e --- /dev/null +++ b/test/srt/test_original_logprobs.py @@ -0,0 +1,196 @@ +"""Test original log probability alignment between SGLang and 
Hugging Face. + +This test suite verifies the correctness of the `origin_logprobs` output (temperature=1) +and the `logprobs` output (temperature=0.5) in SGLang by comparing them against +raw logit-based probabilities computed directly from a reference Hugging Face model. + +The test covers the following scenarios: +- Next-token prediction: Verifies that the log probability of the next token from + SGLang matches the Hugging Face model. +- Top-k logprobs: Ensures that the top-k original logprobs returned by SGLang are + consistent with Hugging Face outputs. +- Specified token IDs: Confirms that the original logprobs for specific token IDs + match the values computed from Hugging Face logits. +""" + +import os +import random +import unittest + +import numpy as np +import torch +import torch.nn.functional as F +from transformers import AutoModelForCausalLM, AutoTokenizer + +import sglang as sgl +from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST + +# ------------------------- Configurable via env ------------------------- # +MODEL_ID = DEFAULT_SMALL_MODEL_NAME_FOR_TEST +PROMPTS = [ + "Hello, my name is", + "The future of AI is", + "The president of the United States is", + "The capital of France is ", +] +TOP_LOGPROBS_NUM = 50 +NUM_RANDOM_TOKEN_IDS = 10 +RTOL = 0.20 +ATOL = 0.00 +# ------------------------------------------------ + +torch.manual_seed(1234) +if torch.cuda.is_available(): + torch.cuda.manual_seed_all(1234) + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + + +class TestOriginalLogprob(unittest.TestCase): + def setUp(self): + # ----- HF side (float32 weights) ----- + self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="right") + self.hf_model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, torch_dtype=torch.float32, device_map="auto" + ) + + # Shared sampling parameters + self.sampling_params = { + "temperature": 0.5, # SGLang samples with temperature 0.5, but original logprobs are computed at temperature 1.0 + 
"top_p": 1.0, + "top_k": 10, + "max_new_tokens": 1, + } + + # --------------------------------------------------------------------- + # Helper: compare one SGLang block (token_logprobs / top_logprobs / ids_logprobs) + # against a reference HF log‑prob vector. + # --------------------------------------------------------------------- + def assert_logprobs_block_equal( + self, + hf_log_probs: torch.Tensor, # [V] + token_log_probs: list, + top_log_probs: list, + ids_log_probs: list, + random_token_ids: list, + tag: str = "", + ): + vals, idxs, _ = zip(*token_log_probs) + sgl_vals = torch.tensor(vals, device=self.hf_model.device, dtype=torch.float32) + sgl_idxs = torch.tensor(idxs, device=self.hf_model.device, dtype=torch.long) + hf_vals = hf_log_probs[sgl_idxs] + + self.assertTrue( + torch.allclose(hf_vals, sgl_vals, rtol=RTOL, atol=ATOL), + msg=f"[{tag}] token‑level mismatch at indices {sgl_idxs.tolist()}", + ) + + hf_topk, _ = torch.topk(hf_log_probs, k=TOP_LOGPROBS_NUM, dim=-1) + + sgl_topk = torch.tensor( + [float(t[0]) for t in top_log_probs[0] if t and t[0] is not None][ + :TOP_LOGPROBS_NUM + ], + dtype=torch.float32, + device=self.hf_model.device, + ) + + k = min(hf_topk.numel(), sgl_topk.numel()) + self.assertTrue( + torch.allclose(hf_topk[:k], sgl_topk[:k], rtol=RTOL, atol=ATOL), + msg=f"[{tag}] top‑k mismatch", + ) + + indices = torch.tensor( + random_token_ids, dtype=torch.long, device=hf_log_probs.device + ) + + hf_token_ids = hf_log_probs[indices] + + sgl_token_ids = torch.tensor( + [v for v, _, _ in ids_log_probs[0]], + device=self.hf_model.device, + dtype=torch.float32, + ) + self.assertTrue( + torch.allclose(hf_token_ids, sgl_token_ids, rtol=RTOL, atol=ATOL), + msg=f"[{tag}] token‑IDs mismatch", + ) + + # Optional: print max abs diff for quick diagnostics + max_diff = torch.max(torch.abs(hf_vals - sgl_vals)).item() + print(f"[{tag}] max|diff| token‑level = {max_diff:.4f}") + + def test_logprob_match(self): + vocab_size = self.tokenizer.vocab_size + + 
for env_val in ["True", "False"]: + with self.subTest(return_original_logprob=env_val): + os.environ["RETURN_ORIGINAL_LOGPROB"] = env_val + + # ----- SGLang side ----- + sgl_engine = sgl.Engine( + model_path=MODEL_ID, + skip_tokenizer_init=True, + trust_remote_code=True, + mem_fraction_static=0.60, + ) + + for prompt in PROMPTS: + random_token_ids = sorted( + random.sample(range(vocab_size), NUM_RANDOM_TOKEN_IDS) + ) + + enc = self.tokenizer(prompt, return_tensors="pt") + input_ids = enc["input_ids"].to(self.hf_model.device) + attn_mask = enc["attention_mask"].to(self.hf_model.device) + + with torch.inference_mode(): + hf_out = self.hf_model( + input_ids=input_ids, + attention_mask=attn_mask, + return_dict=True, + ) + logits = hf_out.logits[:, -1, :] # [1, V] + hf_log_probs = F.log_softmax( + logits.float() / self.sampling_params["temperature"], dim=-1 + )[0] + hf_original_log_probs = F.log_softmax(logits.float(), dim=-1)[0] + + outputs = sgl_engine.generate( + input_ids=input_ids[0].tolist(), + sampling_params=self.sampling_params, + return_logprob=True, + top_logprobs_num=TOP_LOGPROBS_NUM, + token_ids_logprob=random_token_ids, + ) + + if isinstance(outputs, list): + outputs = outputs[0] + meta = outputs["meta_info"] + + # Check original logprobs only if enabled + if env_val.lower() == "true": + self.assert_logprobs_block_equal( + hf_log_probs=hf_original_log_probs, + token_log_probs=meta["output_token_logprobs"], + top_log_probs=meta["output_top_logprobs"], + ids_log_probs=meta["output_token_ids_logprobs"], + random_token_ids=random_token_ids, + tag=f"Original logprobs SGLang vs HF: {prompt} ({env_val})", + ) + else: + # Always check regular logprobs + self.assert_logprobs_block_equal( + hf_log_probs=hf_log_probs, + token_log_probs=meta["output_token_logprobs"], + top_log_probs=meta["output_top_logprobs"], + ids_log_probs=meta["output_token_ids_logprobs"], + random_token_ids=random_token_ids, + tag=f"logprobs SGLang vs HF: {prompt} ({env_val})", + ) + 
sgl_engine.shutdown() + + +if __name__ == "__main__": + unittest.main()