diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py index 1261b6d0c..a00325912 100644 --- a/python/sglang/lang/backend/runtime_endpoint.py +++ b/python/sglang/lang/backend/runtime_endpoint.py @@ -251,11 +251,12 @@ class RuntimeEndpoint(BaseBackend): } obj = self._generate_http_request(s, data) - normalized_prompt_logprobs = [ - r["meta_info"]["normalized_prompt_logprob"] for r in obj - ] input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj] output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj] + normalized_prompt_logprobs = [ + compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"]) + for r in obj + ] # Remove extra token if no token healing occurred for i in range(len(input_token_logprobs)): @@ -319,3 +320,8 @@ class RuntimeEndpoint(BaseBackend): def _assert_success(self, res): if res.status_code != 200: raise RuntimeError(res.json()) + + +def compute_normalized_prompt_logprobs(input_logprobs): + values = [x[0] for x in input_logprobs if x[0]] + return sum(values) / len(values) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 7ca1d51a7..f5b12b48a 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -50,8 +50,6 @@ class LogitsProcessorOutput: next_token_top_logprobs_idx: Optional[List] = None ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor - # The normlaized logprobs of prompts. shape: [#seq] - normalized_prompt_logprobs: torch.Tensor = None # The logprobs of input tokens. shape: [#token] input_token_logprobs: torch.Tensor = None # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k] @@ -195,8 +193,6 @@ class LogitsProcessor(nn.Module): else: input_top_logprobs_val = input_top_logprobs_idx = None - # Compute the normalized logprobs for the requested tokens. - # Note that we pad a zero at the end for easy batching. input_token_logprobs = input_logprobs[ torch.arange(input_logprobs.shape[0], device="cuda"), torch.cat( @@ -206,14 +202,9 @@ class LogitsProcessor(nn.Module): ] ), ] - normalized_prompt_logprobs = self._get_normalized_prompt_logprobs( - input_token_logprobs, - logits_metadata, - ) return LogitsProcessorOutput( next_token_logits=last_logits, - normalized_prompt_logprobs=normalized_prompt_logprobs, input_token_logprobs=input_token_logprobs, input_top_logprobs_val=input_top_logprobs_val, input_top_logprobs_idx=input_top_logprobs_idx, @@ -237,8 +228,6 @@ class LogitsProcessor(nn.Module): if self.do_tensor_parallel_all_gather: logits = tensor_model_parallel_all_gather(logits) - # Compute the normalized logprobs for the requested tokens. - # Note that we pad a zero at the end for easy batching. logits = logits[:, : self.config.vocab_size].float() if self.final_logit_softcapping: @@ -246,27 +235,6 @@ class LogitsProcessor(nn.Module): return logits - @staticmethod - def _get_normalized_prompt_logprobs( - input_token_logprobs: torch.Tensor, - logits_metadata: LogitsMetadata, - ): - logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32) - pruned_lens = torch.tensor( - logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda" - ) - - start = torch.zeros_like(pruned_lens) - start[1:] = torch.cumsum(pruned_lens[:-1], dim=0) - end = torch.clamp( - start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1 - ) - sum_logp = ( - logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start] - ) - normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1) - return normalized_prompt_logprobs - @staticmethod def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata): max_k = max(logits_metadata.top_logprobs_nums) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index b4bc1e7a4..7a0f7b0d5 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -191,7 +191,6 @@ class DetokenizerManager: input_top_logprobs_idx=recv_obj.input_top_logprobs_idx, output_top_logprobs_val=recv_obj.output_top_logprobs_val, output_top_logprobs_idx=recv_obj.output_top_logprobs_idx, - normalized_prompt_logprob=recv_obj.normalized_prompt_logprob, ) ) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 075693c7b..1698dfbeb 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -340,7 +340,6 @@ class BatchTokenIDOut: input_top_logprobs_idx: List[List] output_top_logprobs_val: List[List] output_top_logprobs_idx: List[List] - normalized_prompt_logprob: List[float] @dataclass @@ -366,7 +365,6 @@ class BatchStrOut: input_top_logprobs_idx: List[List] output_top_logprobs_val: List[List] output_top_logprobs_idx: List[List] - normalized_prompt_logprob: List[float] @dataclass diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 3b056cc5d..c375df234 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -280,7 +280,6 @@ class Req: self.top_logprobs_num = top_logprobs_num # Logprobs (return value) - self.normalized_prompt_logprob = None self.input_token_logprobs_val = None self.input_token_logprobs_idx = None self.input_top_logprobs_val = None @@ -344,9 +343,6 @@ class Req: max_prefix_len = min(max_prefix_len, input_len - 1) if self.return_logprob: - if self.normalized_prompt_logprob is None: - # Need at least two tokens to compute normalized logprob - max_prefix_len = min(max_prefix_len, input_len - 2) max_prefix_len = min(max_prefix_len, self.logprob_start_len) max_prefix_len = max(max_prefix_len, 0) diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index d2083d092..7cab55c74 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -433,7 +433,6 @@ class PrefillAdder: or input_tokens <= self.rem_chunk_tokens or ( req.return_logprob - and req.normalized_prompt_logprob is None and req.logprob_start_len != len(req.origin_input_ids) - 1 ) ): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 187216353..169c202d3 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1038,9 +1038,6 @@ class Scheduler: logits_output.input_token_logprobs = ( logits_output.input_token_logprobs.tolist() ) - logits_output.normalized_prompt_logprobs = ( - logits_output.normalized_prompt_logprobs.tolist() - ) # Check finish conditions logprob_pt = 0 @@ -1188,9 +1185,6 @@ class Scheduler: # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored. num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len - if req.normalized_prompt_logprob is None: - req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i] - if req.input_token_logprobs_val is None: input_token_logprobs_val = output.input_token_logprobs[ pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens @@ -1288,15 +1282,12 @@ class Scheduler: input_top_logprobs_idx = [] output_top_logprobs_val = [] output_top_logprobs_idx = [] - normalized_prompt_logprob = [] else: input_token_logprobs_val = input_token_logprobs_idx = ( output_token_logprobs_val ) = output_token_logprobs_idx = input_top_logprobs_val = ( input_top_logprobs_idx - ) = output_top_logprobs_val = output_top_logprobs_idx = ( - normalized_prompt_logprob - ) = None + ) = output_top_logprobs_val = output_top_logprobs_idx = None for req in reqs: if req is skip_req: @@ -1343,7 +1334,6 @@ class Scheduler: input_top_logprobs_idx.append(req.input_top_logprobs_idx) output_top_logprobs_val.append(req.output_top_logprobs_val) output_top_logprobs_idx.append(req.output_top_logprobs_idx) - normalized_prompt_logprob.append(req.normalized_prompt_logprob) # Send to detokenizer if rids: @@ -1370,7 +1360,6 @@ class Scheduler: input_top_logprobs_idx, output_top_logprobs_val, output_top_logprobs_idx, - normalized_prompt_logprob, ) ) else: # embedding or reward model diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index eae3d87d7..4f4e4f7dc 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -796,9 +796,6 @@ class TokenizerManager: recv_obj.output_token_logprobs_idx[recv_obj_index], return_text_in_logprobs, ) - meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[ - recv_obj_index - ] if top_logprobs_num > 0: meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens( diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 4c98c6be2..2aa9c8269 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -151,11 +151,6 @@ class TpModelWorkerClient: logits_output.input_token_logprobs = ( logits_output.input_token_logprobs.to("cpu", non_blocking=True) ) - logits_output.normalized_prompt_logprobs = ( - logits_output.normalized_prompt_logprobs.to( - "cpu", non_blocking=True - ) - ) next_token_ids = next_token_ids.to("cpu", non_blocking=True) copy_done.record() @@ -174,9 +169,6 @@ class TpModelWorkerClient: logits_output.input_token_logprobs = ( logits_output.input_token_logprobs.tolist() ) - logits_output.normalized_prompt_logprobs = ( - logits_output.normalized_prompt_logprobs.tolist() - ) next_token_ids = next_token_ids.tolist() return logits_output, next_token_ids diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py index 411a20b92..219ed3cf6 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -535,7 +535,7 @@ def test_hellaswag_select(): # Compute accuracy accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels)) - assert np.abs(accuracy_gen - accuracy) < 0.01 + assert np.abs(accuracy_gen - accuracy) < 0.05 assert np.abs(latency_gen - latency) < 1 return accuracy, latency diff --git a/scripts/deprecated/test_httpserver_classify.py b/scripts/deprecated/test_httpserver_classify.py deleted file mode 100644 index cb8880299..000000000 --- a/scripts/deprecated/test_httpserver_classify.py +++ /dev/null @@ -1,85 +0,0 @@ -""" -Usage: -python3 -m sglang.launch_server --model-path /model/llama-classification --is-embedding --disable-radix-cache - -python3 test_httpserver_classify.py -""" - -import argparse - -import numpy as np -import requests - - -def get_logits_deprecated(url: str, prompt: str): - response = requests.post( - url + "/generate", - json={ - "text": prompt, - "sampling_params": { - "max_new_tokens": 0, - }, - "return_logprob": True, - }, - ) - return response.json()["meta_info"]["normalized_prompt_logprob"] - - -def get_logits_batch_deprecated(url: str, prompts: list[str]): - response = requests.post( - url + "/generate", - json={ - "text": prompts, - "sampling_params": { - "max_new_tokens": 0, - }, - "return_logprob": True, - }, - ) - ret = response.json() - logits = np.array( - list( - ret[i]["meta_info"]["normalized_prompt_logprob"] - for i in range(len(prompts)) - ) - ) - return logits - - -def get_logits(url: str, prompt: str): - response = requests.post( - url + "/classify", - json={"text": prompt}, - ) - return response.json()["embedding"] - - -def get_logits_batch(url: str, prompts: list[str]): - response = requests.post( - url + "/classify", - json={"text": prompts}, - ) - return np.array([x["embedding"] for x in response.json()]) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default="http://127.0.0.1") - parser.add_argument("--port", type=int, default=30000) - args = parser.parse_args() - - url = f"{args.host}:{args.port}" - - # A single request - prompt = "This is a test prompt.<|eot_id|>" - logits = get_logits(url, prompt) - print(f"{logits=}") - - # A batch of requests - prompts = [ - "This is a test prompt.<|eot_id|>", - "This is another test prompt.<|eot_id|>", - "This is a long long long long test prompt.<|eot_id|>", - ] - logits = get_logits_batch(url, prompts) - print(f"{logits=}") diff --git a/scripts/deprecated/test_httpserver_decode_stream.py b/scripts/deprecated/test_httpserver_decode_stream.py index 955c368d1..616eaf6c4 100644 --- a/scripts/deprecated/test_httpserver_decode_stream.py +++ b/scripts/deprecated/test_httpserver_decode_stream.py @@ -42,7 +42,6 @@ def test_decode_stream(url, return_logprob, top_logprobs_num): if return_logprob: assert data["meta_info"]["input_token_logprobs"] is not None assert data["meta_info"]["output_token_logprobs"] is not None - assert data["meta_info"]["normalized_prompt_logprob"] is not None for logprob, token_id, token_text in data["meta_info"][ "output_token_logprobs" ][prev:]: