diff --git a/docs/sampling_params.md b/docs/sampling_params.md index ed85b935f..add849d21 100644 --- a/docs/sampling_params.md +++ b/docs/sampling_params.md @@ -14,10 +14,14 @@ class GenerateReqInput: sampling_params: Union[List[Dict], Dict] = None # The request id rid: Optional[Union[List[str], str]] = None - # Whether return logprobs of the prompts + # Whether to return logprobs return_logprob: Optional[Union[List[bool], bool]] = None # The start location of the prompt for return_logprob logprob_start_len: Optional[Union[List[int], int]] = None + # The number of top logprobs to return + top_logprobs_num: Optional[Union[List[int], int]] = None + # Whether to detokenize tokens in logprobs + return_text_in_logprobs: bool = False # Whether to stream output stream: bool = False ``` diff --git a/examples/usage/choices_logprob.py b/examples/usage/choices_logprob.py index 6fb28940c..e261668f8 100644 --- a/examples/usage/choices_logprob.py +++ b/examples/usage/choices_logprob.py @@ -3,6 +3,7 @@ Usage: python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 python choices_logprob.py """ + import sglang as sgl @@ -19,9 +20,9 @@ def main(): print("questions:", question) print("choice:", state["tool"]) meta_info = state.get_meta_info("tool") - print("logprobs of choice 1", meta_info["prompt_logprob"][0]) - print("logprobs of choice 2", meta_info["prompt_logprob"][1]) - print('-' * 50) + print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0]) + print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1]) + print("-" * 50) # Run a batch questions = [ @@ -33,9 +34,9 @@ def main(): print("questions:", question) print("choice:", state["tool"]) meta_info = state.get_meta_info("tool") - print("logprobs of choice 1", meta_info["prompt_logprob"][0]) - print("logprobs of choice 2", meta_info["prompt_logprob"][1]) - print('-' * 50) + print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0]) + print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1]) + print("-" * 50) if __name__ == "__main__": diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/backend/runtime_endpoint.py index 3d2ecaa76..899ba09e2 100644 --- a/python/sglang/backend/runtime_endpoint.py +++ b/python/sglang/backend/runtime_endpoint.py @@ -213,6 +213,7 @@ class RuntimeEndpoint(BaseBackend): "sampling_params": {"max_new_tokens": 0}, "return_logprob": True, "logprob_start_len": max(prompt_len - 2, 0), + "return_text_in_logprobs": True, } self._add_images(s, data) res = http_request( @@ -224,13 +225,19 @@ class RuntimeEndpoint(BaseBackend): ) assert res.status_code == 200 obj = res.json() - normalized_prompt_logprob = [ + normalized_prompt_logprobs = [ r["meta_info"]["normalized_prompt_logprob"] for r in obj ] - prompt_logprob = [r["meta_info"]["prompt_logprob"] for r in obj] + decision = choices[np.argmax(normalized_prompt_logprobs)] + prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj] + decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj] - decision = choices[np.argmax(normalized_prompt_logprob)] - return decision, normalized_prompt_logprob, prompt_logprob + return ( + decision, + normalized_prompt_logprobs, + prefill_token_logprobs, + decode_token_logprobs, + ) def concatenate_and_append(self, src_rids: List[str], dst_rid: str): res = http_request( diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 08a8d401b..22b6106da 100644 --- 
a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -454,15 +454,19 @@ class StreamExecutor: self.stream_var_event[name].set() def _execute_select(self, expr: SglSelect): - decision, normalized_prompt_logprob, prompt_logprob = self.backend.select( - self, expr.choices, expr.temperature - ) + ( + decision, + normalized_prompt_logprobs, + prefill_token_logprobs, + decode_token_logprobs, + ) = self.backend.select(self, expr.choices, expr.temperature) if expr.name is not None: name = expr.name self.variables[name] = decision self.meta_info[name] = { - "normalized_prompt_logprob": normalized_prompt_logprob, - "prompt_logprob": prompt_logprob, + "normalized_prompt_logprobs": normalized_prompt_logprobs, + "prefill_token_logprobs": prefill_token_logprobs, + "decode_token_logprobs": decode_token_logprobs, } self.variable_event[name].set() self.text_ += decision diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 980a2cd20..f96471e63 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -13,76 +13,127 @@ class LogitsProcessor(nn.Module): self.config = config self.tp_size = get_tensor_model_parallel_world_size() - def forward(self, input_ids, hidden_states, weight, input_metadata): - last_index = None + def _get_normalized_prompt_logprobs( + self, prefill_token_logprobs, input_metadata: InputMetadata + ): + logprobs_cumsum = torch.cumsum( + prefill_token_logprobs, dim=0, dtype=torch.float32 + ) - # Compute the last index (the first decode token) of each requeast - # if we are in prefill or extend mode. + start = input_metadata.extend_start_loc.clone() + end = start + input_metadata.extend_seq_lens - 2 + start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1) + end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1) + sum_logp = ( + logprobs_cumsum[end] + - logprobs_cumsum[start] + + prefill_token_logprobs[start] + ) + normalized_prompt_logprobs = sum_logp / ( + (input_metadata.extend_seq_lens - 1).clamp(min=1) + ) + + return normalized_prompt_logprobs + + def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata): + if input_metadata.forward_mode == ForwardMode.DECODE: + decode_top_logprobs = [] + for i in range(all_logprobs.shape[0]): + k = input_metadata.top_logprobs_nums[i] + t = all_logprobs[i].topk(k) + v_cpu = t.values.cpu().tolist() + p_cpu = t.indices.cpu().tolist() + decode_top_logprobs.append(list(zip(v_cpu, p_cpu))) + return None, decode_top_logprobs + else: + prefill_top_logprobs, decode_top_logprobs = [], [] + pt = 0 + # NOTE: the GPU-CPU overhead can be reduced + extend_seq_lens_cpu = input_metadata.extend_seq_lens + for i in range(len(input_metadata.extend_seq_lens)): + if extend_seq_lens_cpu[i] == 0: + continue + k = input_metadata.top_logprobs_nums[i] + t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k) + vs_cpu = t.values.cpu().tolist() + ps_cpu = t.indices.cpu().tolist() + prefill_top_logprobs.append( + [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)] + ) + decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1]))) + return prefill_top_logprobs, decode_top_logprobs + + def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata): + # Get last index for next token prediction, except for DECODE mode. 
+ last_index = None if input_metadata.forward_mode != ForwardMode.DECODE: last_index = ( - torch.cumsum( - input_metadata.seq_lens - input_metadata.prefix_lens, - dim=0, - dtype=torch.long, - ) + torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long) - 1 ) - if not input_metadata.return_logprob: - # When logprob is not requested, only compute the last logits. - if input_metadata.forward_mode == ForwardMode.DECODE: - last_hidden = hidden_states - else: - last_hidden = hidden_states[last_index] - hidden_states = None + # Get the last hidden states and last logits + if input_metadata.forward_mode == ForwardMode.DECODE: + last_hidden = hidden_states + else: + last_hidden = hidden_states[last_index] - last_logits = torch.matmul(last_hidden, weight.T) - if self.tp_size > 1: - last_logits = tensor_model_parallel_all_gather(last_logits) - last_logits = last_logits[:, : self.config.vocab_size] - return last_logits, (None, None, None) + last_logits = torch.matmul(last_hidden, weight.T) + if self.tp_size > 1: + last_logits = tensor_model_parallel_all_gather(last_logits) + last_logits = last_logits[:, : self.config.vocab_size] + + # Return only last_logits if logprob is not requested + if not input_metadata.return_logprob: + hidden_states = None + return last_logits, (None, None, None, None, None) else: # When logprob is requested, compute the logits for all tokens. - logits = torch.matmul(hidden_states, weight.T) - if self.tp_size > 1: - logits = tensor_model_parallel_all_gather(logits) - logits = logits[:, : self.config.vocab_size] - all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6) + if input_metadata.forward_mode == ForwardMode.DECODE: + all_logits = last_logits + else: + all_logits = torch.matmul(hidden_states, weight.T) + if self.tp_size > 1: + all_logits = tensor_model_parallel_all_gather(all_logits) + all_logits = all_logits[:, : self.config.vocab_size] + + all_logprobs = torch.log(torch.softmax(all_logits.float(), dim=-1) + 1e-6) + + prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs( + all_logprobs, input_metadata + ) if input_metadata.forward_mode == ForwardMode.DECODE: - last_logits = logits last_logprobs = all_logprobs - prefill_logprobs = normalized_logprobs = None + return last_logits, ( + None, + None, + decode_top_logprobs, + None, + last_logprobs, + ) else: # Compute the logprobs for the last token of each request. - last_logits = logits[last_index] last_logprobs = all_logprobs[last_index] # Compute the logprobs and normalized logprobs for the prefill tokens. # Note that we pad a zero at the end of each sequence for easy computation. 
- prefill_logprobs = all_logprobs[ + prefill_token_logprobs = all_logprobs[ torch.arange(all_logprobs.shape[0], device="cuda"), torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]), ] - logprobs_cumsum = torch.cumsum( - prefill_logprobs, dim=0, dtype=torch.float32 - ) - start = input_metadata.extend_start_loc.clone() - end = start + input_metadata.extend_seq_lens - 2 - start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1) - end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1) - sum_logp = ( - logprobs_cumsum[end] - - logprobs_cumsum[start] - + prefill_logprobs[start] + normalized_prompt_logprobs = self._get_normalized_prompt_logprobs( + prefill_token_logprobs, input_metadata ) - normalized_logprobs = sum_logp / ( - (input_metadata.extend_seq_lens - 1).clamp(min=1) + return last_logits, ( + prefill_token_logprobs, + prefill_top_logprobs, + decode_top_logprobs, + normalized_prompt_logprobs, + last_logprobs, ) - return last_logits, (prefill_logprobs, normalized_logprobs, last_logprobs) - if __name__ == "__main__": all_logprobs = torch.tensor( @@ -93,23 +144,22 @@ if __name__ == "__main__": ) seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda") input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda") - logprobs = torch.zeros(5, dtype=torch.float32, device="cuda") - logprobs = all_logprobs[ + token_logprobs = all_logprobs[ torch.arange(all_logprobs.shape[0], device="cuda"), torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]), ] - logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32) + logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32) len_cumsum = torch.cumsum(seq_lens, dim=0) start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0) end = start + seq_lens - 2 - start.clamp_(min=0, max=logprobs.shape[0] - 1) - end.clamp_(min=0, max=logprobs.shape[0] - 1) - sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start] + start.clamp_(min=0, max=token_logprobs.shape[0] - 1) + end.clamp_(min=0, max=token_logprobs.shape[0] - 1) + sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start] # assert logprobs == [2, _, 2, 4, _] - print("logprobs", logprobs) + print("token logprobs", token_logprobs) print("start", start) print("end", end) print("sum_logp", sum_logp) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index b6817994a..53b1f552a 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -19,10 +19,13 @@ class GenerateReqInput: return_logprob: Optional[Union[List[bool], bool]] = None # The start location of the prompt for return_logprob logprob_start_len: Optional[Union[List[int], int]] = None + # The number of top logprobs to return + top_logprobs_num: Optional[Union[List[int], int]] = None # Whether to detokenize tokens in logprobs return_text_in_logprobs: bool = False # Whether to stream output stream: bool = False + # TODO: make all parameters a Union[List[T], T] to allow for batched requests def post_init(self): is_single = isinstance(self.text, str) @@ -36,6 +39,8 @@ class GenerateReqInput: self.return_logprob = False if self.logprob_start_len is None: self.logprob_start_len = 0 + if self.top_logprobs_num is None: + self.top_logprobs_num = 0 else: num = len(self.text) @@ -64,6 +69,11 @@ class GenerateReqInput: elif not isinstance(self.logprob_start_len, list): self.logprob_start_len = [self.logprob_start_len] * num + if self.top_logprobs_num is None: + 
self.top_logprobs_num = [0] * num + elif not isinstance(self.top_logprobs_num, list): + self.top_logprobs_num = [self.top_logprobs_num] * num + @dataclass class TokenizedGenerateReqInput: @@ -76,6 +86,7 @@ class TokenizedGenerateReqInput: sampling_params: SamplingParams return_logprob: bool logprob_start_len: int + top_logprobs_num: int stream: bool diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py index da5cab42d..f001075bc 100644 --- a/python/sglang/srt/managers/router/infer_batch.py +++ b/python/sglang/srt/managers/router/infer_batch.py @@ -43,6 +43,7 @@ class Req: self.sampling_params = None self.return_logprob = False self.logprob_start_len = 0 + self.top_logprobs_num = 0 self.stream = False self.tokenizer = None @@ -54,9 +55,11 @@ class Req: self.prefix_indices = [] self.last_node = None - self.logprob = None - self.token_logprob = None - self.normalized_logprob = None + self.prefill_token_logprobs = None + self.decode_token_logprobs = None + self.normalized_prompt_logprob = None + self.prefill_top_logprobs = None + self.decode_top_logprobs = None # For constrained decoding self.regex_fsm = None @@ -159,6 +162,9 @@ class Batch: out_cache_loc: torch.Tensor = None out_cache_cont_start: torch.Tensor = None out_cache_cont_end: torch.Tensor = None + + # for processing logprobs + top_logprobs_nums: List[int] = None return_logprob: bool = False # for multimodal @@ -266,6 +272,7 @@ class Batch: self.position_ids_offsets = position_ids_offsets self.extend_num_tokens = extend_num_tokens self.out_cache_loc = out_cache_loc + self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] self.temperatures = torch.tensor( [r.sampling_params.temperature for r in reqs], @@ -415,6 +422,7 @@ class Batch: self.prefix_lens = None self.position_ids_offsets = self.position_ids_offsets[new_indices] self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None + self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices] self.return_logprob = any(req.return_logprob for req in self.reqs) for item in [ @@ -439,6 +447,7 @@ class Batch: [self.position_ids_offsets, other.position_ids_offsets] ) self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None + self.top_logprobs_nums.extend(other.top_logprobs_nums) self.return_logprob = any(req.return_logprob for req in self.reqs) for item in [ diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index 5c9be2095..75c152610 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -260,6 +260,7 @@ class ModelRpcServer: req.sampling_params = recv_req.sampling_params req.return_logprob = recv_req.return_logprob req.logprob_start_len = recv_req.logprob_start_len + req.top_logprobs_num = recv_req.top_logprobs_num req.stream = recv_req.stream req.tokenizer = self.tokenizer @@ -400,28 +401,36 @@ class ModelRpcServer: self.model_config.vocab_size, self.int_token_logit_bias ) - logprobs = None + prefill_token_logprobs = None if batch.extend_num_tokens != 0: # Forward logits, ( - prefill_logprobs, - normalized_logprobs, + prefill_token_logprobs, + prefill_top_logprobs, + decode_top_logprobs, + normalized_prompt_logprobs, last_logprobs, ) = self.model_runner.forward(batch, ForwardMode.EXTEND) - if prefill_logprobs is not None: - logprobs = prefill_logprobs.cpu().tolist() - normalized_logprobs = normalized_logprobs.cpu().tolist() + if 
prefill_token_logprobs is not None: + prefill_token_logprobs = prefill_token_logprobs.cpu().tolist() + normalized_prompt_logprobs = normalized_prompt_logprobs.cpu().tolist() next_token_ids, _ = batch.sample(logits) next_token_ids = next_token_ids.cpu().tolist() else: next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs) - logits = logprobs = normalized_logprobs = last_logprobs = None + ( + logits, + prefill_token_logprobs, + normalized_prompt_logprobs, + last_logprobs, + ) = (None,) * 4 # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead. reqs = batch.reqs + last_token_logprobs = None if last_logprobs is not None: - last_logprobs = ( + last_token_logprobs = ( last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist() ) @@ -432,18 +441,26 @@ class ModelRpcServer: req.output_ids = [next_token_ids[i]] req.check_finished() - if logprobs is not None: - req.logprob = logprobs[pt : pt + req.extend_input_len - 1] - req.normalized_logprob = normalized_logprobs[i] - - # If logprob_start_len > 0, then first logprob_start_len prompt tokens - # will be ignored. - prompt_token_len = len(req.logprob) - token_ids = req.input_ids[-prompt_token_len:] + [next_token_ids[i]] - token_logprobs = req.logprob + [last_logprobs[i]] - req.token_logprob = list(zip(token_ids, token_logprobs)) + if prefill_token_logprobs is not None: + # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored. + req.prefill_token_logprobs = list( + zip( + prefill_token_logprobs[pt : pt + req.extend_input_len - 1], + req.input_ids[-req.extend_input_len + 1 :], + ) + ) if req.logprob_start_len == 0: - req.token_logprob = [(req.input_ids[0], None)] + req.token_logprob + req.prefill_token_logprobs = [ + (None, req.input_ids[0]) + ] + req.prefill_token_logprobs + req.decode_token_logprobs = [ + (last_token_logprobs[i], next_token_ids[i]) + ] + req.prefill_top_logprobs = prefill_top_logprobs[i] + if req.logprob_start_len == 0: + req.prefill_top_logprobs = [None] + req.prefill_top_logprobs + req.decode_top_logprobs = [decode_top_logprobs[i]] + req.normalized_prompt_logprob = normalized_prompt_logprobs[i] pt += req.extend_input_len self.handle_finished_requests(batch) @@ -493,27 +510,29 @@ class ModelRpcServer: batch.prepare_for_decode() # Forward - logits, (_, _, last_logprobs) = self.model_runner.forward( - batch, ForwardMode.DECODE + logits, (_, _, decode_top_logprobs, _, last_logprobs) = ( + self.model_runner.forward(batch, ForwardMode.DECODE) ) next_token_ids, _ = batch.sample(logits) next_token_ids = next_token_ids.cpu().tolist() # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead. 
reqs = batch.reqs + new_token_logprobs = None if last_logprobs is not None: - last_logprobs = last_logprobs[ + new_token_logprobs = last_logprobs[ torch.arange(len(reqs)), next_token_ids ].tolist() # Check finish condition - for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)): + for i, (req, next_token_id) in enumerate(zip(reqs, next_token_ids)): req.completion_tokens_wo_jump_forward += 1 - req.output_ids.append(next_tok_id) + req.output_ids.append(next_token_id) req.check_finished() - if last_logprobs is not None: - req.token_logprob.append((next_tok_id, last_logprobs[i])) + if new_token_logprobs is not None: + req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id)) + req.decode_top_logprobs.append(decode_top_logprobs[i]) self.handle_finished_requests(batch) @@ -558,9 +577,19 @@ class ModelRpcServer: "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward, } if req.return_logprob: - meta_info["prompt_logprob"] = req.logprob - meta_info["token_logprob"] = req.token_logprob - meta_info["normalized_prompt_logprob"] = req.normalized_logprob + ( + meta_info["prefill_token_logprobs"], + meta_info["decode_token_logprobs"], + meta_info["prefill_top_logprobs"], + meta_info["decode_top_logprobs"], + meta_info["normalized_prompt_logprob"], + ) = ( + req.prefill_token_logprobs, + req.decode_token_logprobs, + req.prefill_top_logprobs, + req.decode_top_logprobs, + req.normalized_prompt_logprob, + ) output_meta_info.append(meta_info) output_finished.append(req.finished) diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py index f349819f3..363289f73 100644 --- a/python/sglang/srt/managers/router/model_runner.py +++ b/python/sglang/srt/managers/router/model_runner.py @@ -5,6 +5,7 @@ import logging import pkgutil from dataclasses import dataclass from functools import lru_cache +from typing import List import numpy as np import torch @@ -81,6 +82,7 @@ class InputMetadata: out_cache_cont_end: torch.Tensor = None other_kv_index: torch.Tensor = None + top_logprobs_nums: List[int] = None return_logprob: bool = False # for flashinfer @@ -181,6 +183,7 @@ class InputMetadata: out_cache_loc, out_cache_cont_start=None, out_cache_cont_end=None, + top_logprobs_nums=None, return_logprob=False, ): batch_size = len(req_pool_indices) @@ -229,6 +232,7 @@ class InputMetadata: out_cache_loc=out_cache_loc, out_cache_cont_start=out_cache_cont_start, out_cache_cont_end=out_cache_cont_end, + top_logprobs_nums=top_logprobs_nums, return_logprob=return_logprob, other_kv_index=other_kv_index, ) @@ -377,6 +381,7 @@ class ModelRunner: prefix_lens=batch.prefix_lens, position_ids_offsets=batch.position_ids_offsets, out_cache_loc=batch.out_cache_loc, + top_logprobs_nums=batch.top_logprobs_nums, return_logprob=batch.return_logprob, ) return self.model.forward( @@ -394,6 +399,7 @@ class ModelRunner: prefix_lens=batch.prefix_lens, position_ids_offsets=batch.position_ids_offsets, out_cache_loc=batch.out_cache_loc, + top_logprobs_nums=batch.top_logprobs_nums, return_logprob=batch.return_logprob, ) return self.model.forward( @@ -413,6 +419,7 @@ class ModelRunner: out_cache_loc=batch.out_cache_loc, out_cache_cont_start=batch.out_cache_cont_start, out_cache_cont_end=batch.out_cache_cont_end, + top_logprobs_nums=batch.top_logprobs_nums, return_logprob=batch.return_logprob, ) return self.model.forward( @@ -430,6 +437,7 @@ class ModelRunner: prefix_lens=batch.prefix_lens, position_ids_offsets=batch.position_ids_offsets, 
out_cache_loc=batch.out_cache_loc, + top_logprobs_nums=batch.top_logprobs_nums, return_logprob=batch.return_logprob, ) return self.model.forward( diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 7947ca2ff..183c34bef 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -173,6 +173,7 @@ class TokenizerManager: sampling_params=sampling_params, return_logprob=obj.return_logprob, logprob_start_len=obj.logprob_start_len, + top_logprobs_num=obj.top_logprobs_num, stream=obj.stream, ) self.send_to_router.send_pyobj(tokenized_obj) @@ -215,6 +216,7 @@ class TokenizerManager: sampling_params=sampling_params, return_logprob=obj.return_logprob[i], logprob_start_len=obj.logprob_start_len[i], + top_logprobs_num=obj.top_logprobs_num[i], stream=obj.stream, ) self.send_to_router.send_pyobj(tokenized_obj) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index e9961305d..25de8e16c 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -123,31 +123,97 @@ async def flush_cache(): ) -async def detokenize_logprob_tokens(token_logprobs): - token_ids = [tid for tid, _ in token_logprobs] +async def detokenize_logprob_tokens(token_logprobs, decode_to_text): + if not decode_to_text: + return [(logprob, token_id, None) for logprob, token_id in token_logprobs] + + token_ids = [tid for _, tid in token_logprobs] token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids)) - return [(text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)] + return [ + (logprob, token_id, token_text) + for (logprob, token_id), token_text in zip(token_logprobs, token_texts) + ] + + +async def detokenize_top_logprobs_tokens(top_logprobs, decode_to_text): + for i, t in enumerate(top_logprobs): + if top_logprobs[i] is not None: + top_logprobs[i] = await detokenize_logprob_tokens(t, decode_to_text) + return top_logprobs + + +async def handle_token_logprobs_results(obj: GenerateReqInput, ret): + """Handle token logprob results, converting token ids to text if needed. + + Args: + obj (GenerateReqInput): The request object. + ret (Union[Dict, List[Dict]]): The response object. + """ + # NOTE: `ret` may be a single dict or a list of dicts, because one HTTP request can carry multiple generation requests.
+ + async def convert_style(r, return_text): + r["meta_info"]["prefill_token_logprobs"] = await detokenize_logprob_tokens( + r["meta_info"]["prefill_token_logprobs"], return_text + ) + r["meta_info"]["decode_token_logprobs"] = await detokenize_logprob_tokens( + r["meta_info"]["decode_token_logprobs"], return_text + ) + r["meta_info"]["prefill_top_logprobs"] = await detokenize_top_logprobs_tokens( + r["meta_info"]["prefill_top_logprobs"], return_text + ) + r["meta_info"]["decode_top_logprobs"] = await detokenize_top_logprobs_tokens( + r["meta_info"]["decode_top_logprobs"], return_text + ) + + if isinstance(obj.text, str): + if obj.return_logprob: + await convert_style(ret, obj.return_text_in_logprobs) + else: + for i, r in enumerate(ret): + if obj.return_logprob[i]: + await convert_style(r, obj.return_text_in_logprobs) async def stream_generator(obj: GenerateReqInput): async for out in tokenizer_manager.generate_request(obj): - if obj.return_logprob and obj.return_text_in_logprobs: - out["meta_info"]["token_logprob"] = await detokenize_logprob_tokens( - out["meta_info"]["token_logprob"] - ) + await handle_token_logprobs_results(obj, out) yield out -async def make_openai_style_logprobs(token_logprobs): +async def make_openai_style_logprobs( + prefill_token_logprobs=None, + decode_token_logprobs=None, + prefill_top_logprobs=None, + decode_top_logprobs=None, +): ret_logprobs = LogProbs() - for token_text, token_logprob in token_logprobs: - ret_logprobs.tokens.append(token_text) - ret_logprobs.token_logprobs.append(token_logprob) + def append_token_logprobs(token_logprobs): + for logprob, _, token_text in token_logprobs: + ret_logprobs.tokens.append(token_text) + ret_logprobs.token_logprobs.append(logprob) + + # Not Supported yet + ret_logprobs.text_offset.append(-1) + + def append_top_logprobs(top_logprobs): + for tokens in top_logprobs: + if tokens is not None: + ret_logprobs.top_logprobs.append( + {token[2]: token[0] for token in tokens} + ) + else: + ret_logprobs.top_logprobs.append(None) + + if prefill_token_logprobs is not None: + append_token_logprobs(prefill_token_logprobs) + if decode_token_logprobs is not None: + append_token_logprobs(decode_token_logprobs) + if prefill_top_logprobs is not None: + append_top_logprobs(prefill_top_logprobs) + if decode_top_logprobs is not None: + append_top_logprobs(decode_top_logprobs) - # Not supported yet. - ret_logprobs.top_logprobs.append({}) - ret_logprobs.text_offset.append(-1) return ret_logprobs @@ -165,10 +231,7 @@ async def generate_request(obj: GenerateReqInput): return StreamingResponse(stream_results(), media_type="text/event-stream") ret = await tokenizer_manager.generate_request(obj).__anext__() - if obj.return_logprob and obj.return_text_in_logprobs: - ret["meta_info"]["token_logprob"] = await detokenize_logprob_tokens( - ret["meta_info"]["token_logprob"] - ) + await handle_token_logprobs_results(obj, ret) return ret @@ -192,7 +255,8 @@ async def v1_completions(raw_request: Request): "frequency_penalty": request.frequency_penalty, "regex": request.regex, }, - return_logprob=request.logprobs is not None, + return_logprob=request.logprobs is not None and request.logprobs > 0, + top_logprobs_num=request.logprobs if request.logprobs is not None else 0, return_text_in_logprobs=True, stream=request.stream, ) @@ -212,15 +276,32 @@ async def v1_completions(raw_request: Request): if request.echo: # Prepend prompt in response text. text = request.prompt + text - else: - # Skip prompt tokens if echo is disabled. 
- n_prev_token = prompt_tokens - if request.logprobs is not None: + if request.logprobs: + # The first chunk and echo is enabled. + if not stream_buffer and request.echo: + prefill_token_logprobs = content["meta_info"][ + "prefill_token_logprobs" + ] + prefill_top_logprobs = content["meta_info"][ + "prefill_top_logprobs" + ] + else: + prefill_token_logprobs = None + prefill_top_logprobs = None + logprobs = await make_openai_style_logprobs( - content["meta_info"]["token_logprob"][n_prev_token:] + prefill_token_logprobs=prefill_token_logprobs, + prefill_top_logprobs=prefill_top_logprobs, + decode_token_logprobs=content["meta_info"][ + "decode_token_logprobs" + ][n_prev_token:], + decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][ + n_prev_token: + ], ) - n_prev_token = len(content["meta_info"]["token_logprob"]) + + n_prev_token = len(content["meta_info"]["decode_token_logprobs"]) else: logprobs = None @@ -255,20 +336,26 @@ async def v1_completions(raw_request: Request): prompt_tokens = ret["meta_info"]["prompt_tokens"] completion_tokens = ret["meta_info"]["completion_tokens"] text = ret["text"] - token_logprob_pos = prompt_tokens if request.echo: - token_logprob_pos = 0 text = request.prompt + text - else: - token_logprob_pos = prompt_tokens - logprobs = ( - await make_openai_style_logprobs( - ret["meta_info"]["token_logprob"][token_logprob_pos:] + if request.logprobs: + if request.echo: + prefill_token_logprobs = ret["meta_info"]["prefill_token_logprobs"] + prefill_top_logprobs = ret["meta_info"]["prefill_top_logprobs"] + else: + prefill_token_logprobs = None + prefill_top_logprobs = None + + logprobs = await make_openai_style_logprobs( + prefill_token_logprobs=prefill_token_logprobs, + prefill_top_logprobs=prefill_top_logprobs, + decode_token_logprobs=ret["meta_info"]["decode_token_logprobs"], + decode_top_logprobs=ret["meta_info"]["decode_top_logprobs"], ) - if request.logprobs is not None - else None - ) + else: + logprobs = None + choice_data = CompletionResponseChoice( index=0, text=text, diff --git a/test/srt/test_httpserver_decode.py b/test/srt/test_httpserver_decode.py index fac301f6a..04897b398 100644 --- a/test/srt/test_httpserver_decode.py +++ b/test/srt/test_httpserver_decode.py @@ -9,11 +9,12 @@ The capital of France is Paris.\nThe capital of the United States is Washington, """ import argparse +import json import requests -def test_decode(url, return_logprob): +def test_decode(url, return_logprob, top_logprobs_num, return_text): response = requests.post( url + "/generate", json={ @@ -23,10 +24,13 @@ def test_decode(url, return_logprob): "max_new_tokens": 32, }, "return_logprob": return_logprob, + "top_logprobs_num": top_logprobs_num, + "return_text_in_logprobs": return_text, "logprob_start_len": 0, }, ) - print(response.json()) + print(json.dumps(response.json())) + print("=" * 100) if __name__ == "__main__": @@ -37,5 +41,8 @@ if __name__ == "__main__": url = f"{args.host}:{args.port}" - test_decode(url, False) - test_decode(url, True) + test_decode(url, False, 0, False) + test_decode(url, True, 0, False) + test_decode(url, True, 0, True) + test_decode(url, True, 3, False) + test_decode(url, True, 3, True) diff --git a/test/srt/test_httpserver_decode_stream.py b/test/srt/test_httpserver_decode_stream.py index a6d29c548..7c2b5da1e 100644 --- a/test/srt/test_httpserver_decode_stream.py +++ b/test/srt/test_httpserver_decode_stream.py @@ -13,7 +13,7 @@ import json import requests -def test_decode_stream(url, return_logprob): +def test_decode_stream(url, 
return_logprob, top_logprobs_num): response = requests.post( url + "/generate", json={ @@ -24,6 +24,8 @@ def test_decode_stream(url, return_logprob): }, "stream": True, "return_logprob": return_logprob, + "top_logprobs_num": top_logprobs_num, + "return_text_in_logprobs": True, }, stream=True, ) @@ -37,19 +39,20 @@ def test_decode_stream(url, return_logprob): data = json.loads(chunk[5:].strip("\n")) if return_logprob: - assert data["meta_info"]["prompt_logprob"] is not None - assert data["meta_info"]["token_logprob"] is not None + assert data["meta_info"]["prefill_token_logprobs"] is not None + assert data["meta_info"]["decode_token_logprobs"] is not None assert data["meta_info"]["normalized_prompt_logprob"] is not None - if prev == 0: # Skip prompt logprobs - prev = data["meta_info"]["prompt_tokens"] - for token_txt, _, logprob in data["meta_info"]["token_logprob"][prev:]: - print(f"{token_txt}\t{logprob}", flush=True) - prev = len(data["meta_info"]["token_logprob"]) + for logprob, token_id, token_text in data["meta_info"][ + "decode_token_logprobs" + ][prev:]: + print(f"{token_text:12s}\t{logprob}\t{token_id}", flush=True) + prev = len(data["meta_info"]["decode_token_logprobs"]) else: output = data["text"].strip() print(output[prev:], end="", flush=True) prev = len(output) - print("") + + print("=" * 100) if __name__ == "__main__": @@ -60,5 +63,6 @@ if __name__ == "__main__": url = f"{args.host}:{args.port}" - test_decode_stream(url, False) - test_decode_stream(url, True) + test_decode_stream(url, False, 0) + test_decode_stream(url, True, 0) + test_decode_stream(url, True, 3) diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 377aa01ce..4cc50af85 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -34,6 +34,7 @@ def test_completion(args, echo, logprobs): if echo: assert text.startswith("The capital of France is") if logprobs: + print(response.choices[0].logprobs.top_logprobs) assert response.choices[0].logprobs if echo: assert response.choices[0].logprobs.token_logprobs[0] == None @@ -44,6 +45,7 @@ def test_completion(args, echo, logprobs): assert response.usage.prompt_tokens > 0 assert response.usage.completion_tokens > 0 assert response.usage.total_tokens > 0 + print("=" * 100) def test_completion_stream(args, echo, logprobs): @@ -68,13 +70,14 @@ def test_completion_stream(args, echo, logprobs): f"{r.choices[0].text:12s}\t" f"{r.choices[0].logprobs.token_logprobs}", flush=True, ) + print(r.choices[0].logprobs.top_logprobs) else: print(r.choices[0].text, end="", flush=True) assert r.id assert r.usage.prompt_tokens > 0 assert r.usage.completion_tokens > 0 assert r.usage.total_tokens > 0 - print() + print("=" * 100) def test_chat_completion(args): @@ -94,6 +97,7 @@ def test_chat_completion(args): assert response.usage.prompt_tokens > 0 assert response.usage.completion_tokens > 0 assert response.usage.total_tokens > 0 + print("=" * 100) def test_chat_completion_image(args): @@ -124,6 +128,7 @@ def test_chat_completion_image(args): assert response.usage.prompt_tokens > 0 assert response.usage.completion_tokens > 0 assert response.usage.total_tokens > 0 + print("=" * 100) def test_chat_completion_stream(args): @@ -149,7 +154,7 @@ def test_chat_completion_stream(args): if not data.content: continue print(data.content, end="", flush=True) - print() + print("=" * 100) def test_regex(args): @@ -174,6 +179,7 @@ def test_regex(args): ) text = response.choices[0].message.content print(json.loads(text)) + print("=" * 100) if __name__ == 
"__main__": @@ -188,10 +194,14 @@ if __name__ == "__main__": test_completion(args, echo=True, logprobs=False) test_completion(args, echo=False, logprobs=True) test_completion(args, echo=True, logprobs=True) + test_completion(args, echo=False, logprobs=3) + test_completion(args, echo=True, logprobs=3) test_completion_stream(args, echo=False, logprobs=False) test_completion_stream(args, echo=True, logprobs=False) test_completion_stream(args, echo=False, logprobs=True) test_completion_stream(args, echo=True, logprobs=True) + test_completion_stream(args, echo=False, logprobs=3) + test_completion_stream(args, echo=True, logprobs=3) test_chat_completion(args) test_chat_completion_stream(args) test_regex(args)