From b997a18d74213e905052c47941eebefd36a4d276 Mon Sep 17 00:00:00 2001 From: yichuan~ <73766326+yichuan520030910320@users.noreply.github.com> Date: Sun, 18 Aug 2024 23:45:41 -0700 Subject: [PATCH] [Feat]Add support for optional start len of logprobs (#1035) Co-authored-by: Ying Sheng Co-authored-by: Yineng Zhang Co-authored-by: Lianmin Zheng Co-authored-by: Liangsheng Yin --- python/sglang/srt/layers/logits_processor.py | 61 +++++++++++++++---- python/sglang/srt/managers/io_struct.py | 4 +- .../sglang/srt/managers/tokenizer_manager.py | 7 +++ .../srt/model_executor/forward_batch_info.py | 18 +++++- python/sglang/srt/openai_api/adapter.py | 44 +++++++++---- python/sglang/srt/server.py | 2 + python/sglang/test/runners.py | 1 + test/srt/test_openai_server.py | 7 +-- 8 files changed, 113 insertions(+), 31 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 2e0ce6d5c..a5ba06de0 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -55,6 +55,9 @@ class LogitsMetadata: extend_start_loc: Optional[torch.Tensor] = None top_logprobs_nums: Optional[List[int]] = None + extend_seq_lens_cpu: List[int] = None + logprob_start_lens_cpu: List[int] = None + @classmethod def from_input_metadata(cls, input_metadata: InputMetadata): return cls( @@ -63,6 +66,8 @@ class LogitsMetadata: extend_start_loc=input_metadata.extend_start_loc, return_logprob=input_metadata.return_logprob, top_logprobs_nums=input_metadata.top_logprobs_nums, + extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu, + logprob_start_lens_cpu=input_metadata.logprob_start_lens_cpu, ) @@ -75,12 +80,16 @@ class LogitsProcessor(nn.Module): ) def _get_normalized_prompt_logprobs( - self, input_token_logprobs, logits_metadata: LogitsMetadata + self, + input_token_logprobs: torch.Tensor, + cum_start_len0: torch.Tensor, + cum_start_len1: torch.Tensor, + logits_metadata: LogitsMetadata, ): logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32) - start = logits_metadata.extend_start_loc.clone() - end = start + logits_metadata.extend_seq_lens - 2 + start = logits_metadata.extend_start_loc.clone() - cum_start_len0 + end = start + logits_metadata.extend_seq_lens - 2 - cum_start_len1 start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1) end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1) sum_logp = ( @@ -93,7 +102,7 @@ class LogitsProcessor(nn.Module): return normalized_prompt_logprobs @staticmethod - def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata): + def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata): if logits_metadata.forward_mode == ForwardMode.DECODE: output_top_logprobs = [] max_k = max(logits_metadata.top_logprobs_nums) @@ -107,7 +116,7 @@ class LogitsProcessor(nn.Module): # TODO: vectorize the code below input_top_logprobs, output_top_logprobs = [], [] pt = 0 - extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist() + extend_seq_lens_cpu = logits_metadata.extend_seq_lens_cpu max_k = max(logits_metadata.top_logprobs_nums) ret = all_logprobs.topk(max_k, dim=1) @@ -115,26 +124,30 @@ class LogitsProcessor(nn.Module): indices = ret.indices.tolist() for i, extend_seq_len in enumerate(extend_seq_lens_cpu): + start_len = logits_metadata.logprob_start_lens_cpu[i] + pruned_len = extend_seq_len - start_len + if extend_seq_len == 0: input_top_logprobs.append([]) output_top_logprobs.append([]) continue + k = logits_metadata.top_logprobs_nums[i] input_top_logprobs.append( [ list(zip(values[pt + j][:k], indices[pt + j][:k])) - for j in range(extend_seq_len - 1) + for j in range(pruned_len - 1) ] ) output_top_logprobs.append( list( zip( - values[pt + extend_seq_len - 1][:k], - indices[pt + extend_seq_len - 1][:k], + values[pt + pruned_len - 1][:k], + indices[pt + pruned_len - 1][:k], ) ) ) - pt += extend_seq_len + pt += pruned_len return input_top_logprobs, output_top_logprobs @@ -205,7 +218,23 @@ class LogitsProcessor(nn.Module): output_top_logprobs=output_top_logprobs, ) else: - all_logits = torch.matmul(hidden_states, weight.T) + pt, states, pruned_input_ids = 0, [], [] + for i, extend_len in enumerate(logits_metadata.extend_seq_lens_cpu): + start_len = logits_metadata.logprob_start_lens_cpu[i] + states.append(hidden_states[pt + start_len : pt + extend_len]) + pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len]) + pt += extend_len + + states = torch.cat(states, dim=0) + pruned_input_ids = torch.cat(pruned_input_ids, dim=0) + + cum_start_len1 = torch.tensor( + logits_metadata.logprob_start_lens_cpu, device="cuda" + ).cumsum(0) + cum_start_len0 = torch.zeros_like(cum_start_len1) + cum_start_len0[1:] = cum_start_len1[:-1] + + all_logits = torch.matmul(states, weight.T) if self.do_tensor_parallel_all_gather: all_logits = tensor_model_parallel_all_gather(all_logits) all_logits = all_logits[:, : self.config.vocab_size].float() @@ -230,19 +259,25 @@ class LogitsProcessor(nn.Module): else: input_top_logprobs = output_top_logprobs = None - last_logprobs = all_logprobs[last_index] + last_logprobs = all_logprobs[last_index - cum_start_len1] # Compute the logprobs and normalized logprobs for the prefill tokens. # Note that we pad a zero at the end of each sequence for easy computation. input_token_logprobs = all_logprobs[ torch.arange(all_logprobs.shape[0], device="cuda"), - torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]), + torch.cat([pruned_input_ids[1:], torch.tensor([0], device="cuda")]), ] normalized_prompt_logprobs = self._get_normalized_prompt_logprobs( - input_token_logprobs, logits_metadata + input_token_logprobs, + cum_start_len0, + cum_start_len1, + logits_metadata, ) + # Remove the last token logprob for the prefill tokens. + input_token_logprobs = input_token_logprobs[:-1] + return LogitProcessorOutput( next_token_logits=last_logits, next_token_logprobs=last_logprobs, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 82f280b60..3a0ecd8f6 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -75,7 +75,7 @@ class GenerateReqInput: if self.return_logprob is None: self.return_logprob = False if self.logprob_start_len is None: - self.logprob_start_len = 0 + self.logprob_start_len = -1 if self.top_logprobs_num is None: self.top_logprobs_num = 0 else: @@ -141,7 +141,7 @@ class GenerateReqInput: self.return_logprob = [self.return_logprob] * num if self.logprob_start_len is None: - self.logprob_start_len = [0] * num + self.logprob_start_len = [-1] * num elif not isinstance(self.logprob_start_len, list): self.logprob_start_len = [self.logprob_start_len] * num diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index d5fbfe05d..edbfff3ec 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -195,6 +195,9 @@ class TokenizerManager: if not_use_index else obj.logprob_start_len[index] ) + if return_logprob and logprob_start_len == -1: + logprob_start_len = len(input_ids) - 1 + top_logprobs_num = ( obj.top_logprobs_num if not_use_index @@ -245,6 +248,8 @@ class TokenizerManager: top_logprobs_num = obj.top_logprobs_num[0] if self.is_generation: + if return_logprob and logprob_start_len == -1: + logprob_start_len = len(input_ids) - 1 tokenized_obj = TokenizedGenerateReqInput( rid, input_text, @@ -334,6 +339,8 @@ class TokenizerManager: sampling_params = self._get_sampling_params(obj.sampling_params[index]) if self.is_generation: + if obj.return_logprob[index] and obj.logprob_start_len[index] == -1: + obj.logprob_start_len[index] = len(input_ids) - 1 pixel_values, image_hash, image_size = await self._get_pixel_values( obj.image_data[index] ) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 3cf68eab2..bac0a0537 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -61,9 +61,11 @@ class InputMetadata: extend_start_loc: torch.Tensor = None extend_no_prefix: bool = None - # Output options + # For logprob return_logprob: bool = False top_logprobs_nums: List[int] = None + extend_seq_lens_cpu: List[int] = None + logprob_start_lens_cpu: List[int] = None # For multimodal pixel_values: List[torch.Tensor] = None @@ -139,6 +141,7 @@ class InputMetadata: def compute_extend_infos(self, batch: ScheduleBatch): if self.forward_mode == ForwardMode.DECODE: self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None + self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None else: extend_lens_cpu = [ len(r.fill_ids) - batch.prefix_lens_cpu[i] @@ -149,6 +152,19 @@ class InputMetadata: self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0) self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu) + self.extend_seq_lens_cpu = extend_lens_cpu + self.logprob_start_lens_cpu = [ + ( + min( + req.logprob_start_len - batch.prefix_lens_cpu[i], + extend_lens_cpu[i] - 1, + ) + if req.logprob_start_len >= batch.prefix_lens_cpu[i] + else extend_lens_cpu[i] - 1 # Fake extend, actually decode + ) + for i, req in enumerate(batch.reqs) + ] + @classmethod def from_schedule_batch( cls, diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 15aa701cb..5d7bb7af7 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -20,6 +20,7 @@ import json import os import time import uuid +import warnings from http import HTTPStatus from typing import Dict, List, Optional @@ -383,20 +384,33 @@ async def v1_retrieve_file_content(file_id: str): return StreamingResponse(iter_file(), media_type="application/octet-stream") -def v1_generate_request(all_requests): +def v1_generate_request(all_requests: List[CompletionRequest]): prompts = [] sampling_params_list = [] return_logprobs = [] + logprob_start_lens = [] top_logprobs_nums = [] + + # NOTE: with openai API, the prompt's logprobs are always not computed first_prompt_type = type(all_requests[0].prompt) + for request in all_requests: + assert ( + type(request.prompt) == first_prompt_type + ), "All prompts must be of the same type in file input settings" + if len(all_requests) > 1 and request.n > 1: + raise ValueError( + "Parallel sampling is not supported for completions from files" + ) + if request.echo and request.logprobs: + warnings.warn( + "Echo is not compatible with logprobs. " + "To compute logprobs of input prompt, please use SGLang /request API." + ) for request in all_requests: - prompt = request.prompt - assert ( - type(prompt) == first_prompt_type - ), "All prompts must be of the same type in file input settings" - prompts.append(prompt) + prompts.append(request.prompt) return_logprobs.append(request.logprobs is not None and request.logprobs > 0) + logprob_start_lens.append(-1) top_logprobs_nums.append( request.logprobs if request.logprobs is not None else 0 ) @@ -416,14 +430,11 @@ def v1_generate_request(all_requests): "ignore_eos": request.ignore_eos, } ) - if len(all_requests) > 1 and request.n > 1: - raise ValueError( - "Parallel sampling is not supported for completions from files" - ) if len(all_requests) == 1: prompt = prompts[0] sampling_params_list = sampling_params_list[0] + logprob_start_lens = logprob_start_lens[0] return_logprobs = return_logprobs[0] top_logprobs_nums = top_logprobs_nums[0] if isinstance(prompt, str) or isinstance(prompt[0], str): @@ -441,6 +452,7 @@ def v1_generate_request(all_requests): sampling_params=sampling_params_list, return_logprob=return_logprobs, top_logprobs_num=top_logprobs_nums, + logprob_start_len=logprob_start_lens, return_text_in_logprobs=True, stream=all_requests[0].stream, ) @@ -694,12 +706,18 @@ async def v1_completions(tokenizer_manager, raw_request: Request): return response -def v1_chat_generate_request(all_requests, tokenizer_manager): +def v1_chat_generate_request( + all_requests: List[ChatCompletionRequest], tokenizer_manager +): input_ids = [] sampling_params_list = [] image_data_list = [] return_logprobs = [] + logprob_start_lens = [] top_logprobs_nums = [] + + # NOTE: with openai API, the prompt's logprobs are always not computed + for request in all_requests: # Prep the data needed for the underlying GenerateReqInput: # - prompt: The full prompt string. @@ -732,6 +750,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager): image_data = None input_ids.append(prompt_ids) return_logprobs.append(request.logprobs) + logprob_start_lens.append(-1) top_logprobs_nums.append(request.top_logprobs) sampling_params_list.append( { @@ -758,17 +777,20 @@ def v1_chat_generate_request(all_requests, tokenizer_manager): sampling_params_list = sampling_params_list[0] image_data = image_data_list[0] return_logprobs = return_logprobs[0] + logprob_start_lens = logprob_start_lens[0] top_logprobs_nums = top_logprobs_nums[0] else: if isinstance(input_ids[0], str): prompt_kwargs = {"text": input_ids} else: prompt_kwargs = {"input_ids": input_ids} + adapted_request = GenerateReqInput( **prompt_kwargs, image_data=image_data, sampling_params=sampling_params_list, return_logprob=return_logprobs, + logprob_start_len=logprob_start_lens, top_logprobs_num=top_logprobs_nums, stream=all_requests[0].stream, return_text_in_logprobs=True, diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 9028c1230..997b805cc 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -559,12 +559,14 @@ class Runtime: prompt: str, sampling_params: Optional[Dict] = None, return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, ): json_data = { "text": prompt, "sampling_params": sampling_params, "return_logprob": return_logprob, + "logprob_start_len": logprob_start_len, "top_logprobs_num": top_logprobs_num, } response = requests.post( diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index e325ecb71..9386d7f7a 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -209,6 +209,7 @@ class SRTRunner: prompt, sampling_params=sampling_params, return_logprob=True, + logprob_start_len=0, top_logprobs_num=NUM_TOP_LOGPROBS, ) response = json.loads(response) diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 872424756..c62fefe9f 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -70,13 +70,12 @@ class TestOpenAIServer(unittest.TestCase): assert isinstance(response.choices[0].logprobs.tokens[0], str) assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict) ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1]) + # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some out_put id maps to the same output token and duplicate in the map # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}" + assert ret_num_top_logprobs > 0 - if echo: - assert response.choices[0].logprobs.token_logprobs[0] == None - else: - assert response.choices[0].logprobs.token_logprobs[0] != None + assert response.choices[0].logprobs.token_logprobs[0] != None assert response.id assert response.created