From b146555749f84a684c7cf5e9d2950ca474b82de2 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Mon, 19 May 2025 18:21:29 -0700 Subject: [PATCH] Revert "Implement `return_hidden_states` for the OpenAI API (#6137)" (#6440) --- python/sglang/srt/openai_api/adapter.py | 88 +--------- python/sglang/srt/openai_api/protocol.py | 31 +--- test/srt/test_openai_server.py | 209 ++++++----------------- 3 files changed, 53 insertions(+), 275 deletions(-) diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 14b521740..9f54641d9 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -531,7 +531,6 @@ def v1_generate_request( logprob_start_lens = [] top_logprobs_nums = [] lora_paths = [] - return_hidden_states = [] for request in all_requests: # NOTE: with openai API, the prompt's logprobs are always not computed @@ -578,7 +577,6 @@ def v1_generate_request( top_logprobs_nums.append( request.logprobs if request.logprobs is not None else 0 ) - return_hidden_states.append(request.return_hidden_states) if len(all_requests) == 1: if isinstance(prompts[0], str) or isinstance(prompts[0][0], str): @@ -590,7 +588,6 @@ def v1_generate_request( logprob_start_lens = logprob_start_lens[0] top_logprobs_nums = top_logprobs_nums[0] lora_paths = lora_paths[0] - return_hidden_states = return_hidden_states[0] else: if isinstance(prompts[0], str) or isinstance(prompts[0][0], str): prompt_kwargs = {"text": prompts} @@ -607,7 +604,6 @@ def v1_generate_request( stream=all_requests[0].stream, rid=request_ids, lora_path=lora_paths, - return_hidden_states=return_hidden_states, ) return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0] @@ -673,17 +669,6 @@ def v1_generate_response( else: logprobs = None - hidden_states = None - if isinstance(request, list) and request[idx].return_hidden_states: - hidden_states = ret_item["meta_info"].get("hidden_states", None) - elif (not isinstance(request, list)) and request.return_hidden_states: - hidden_states = ret_item["meta_info"].get("hidden_states", None) - if hidden_states is not None: - hidden_states = hidden_states[1:] # trim off the prefill - hidden_states = ( - hidden_states[-1] if len(hidden_states) > 0 else [] - ) # slice out the last token - finish_reason = ret_item["meta_info"]["finish_reason"] if to_file: @@ -710,7 +695,6 @@ def v1_generate_response( if finish_reason and "matched" in finish_reason else None ), - hidden_states=hidden_states, ) choices.append(choice_data) @@ -735,7 +719,6 @@ def v1_generate_response( + ret[i]["meta_info"]["completion_tokens"], }, "system_fingerprint": None, - "hidden_states": hidden_states, }, } responses.append(response) @@ -780,7 +763,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request): prompt_tokens = {} completion_tokens = {} cached_tokens = {} - hidden_states = None try: async for content in tokenizer_manager.generate_request( @@ -795,9 +777,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request): prompt_tokens[index] = content["meta_info"]["prompt_tokens"] completion_tokens[index] = content["meta_info"]["completion_tokens"] cached_tokens[index] = content["meta_info"].get("cached_tokens", 0) - hidden_states = ( - content["meta_info"].get("hidden_states", None) or hidden_states - ) if not stream_buffer: # The first chunk if request.echo: @@ -903,25 +882,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request): total_tokens=total_prompt_tokens + total_completion_tokens, prompt_tokens_details=prompt_tokens_details, ) - if request.return_hidden_states and hidden_states: - hidden_states = hidden_states[1:] # trim off the prefill - hidden_states = ( - hidden_states[-1] if len(hidden_states) > 0 else [] - ) # slice out the last token - hidden_states_chunk = CompletionStreamResponse( - id=content["meta_info"]["id"], - created=created, - choices=[ - CompletionResponseStreamChoice( - text="", - index=index, - hidden_states=hidden_states, - finish_reason=None, - ) - ], - model=request.model, - ) - yield f"data: {hidden_states_chunk.model_dump_json()}\n\n" + final_usage_chunk = CompletionStreamResponse( id=content["meta_info"]["id"], created=created, @@ -998,7 +959,6 @@ def v1_chat_generate_request( top_logprobs_nums = [] modalities_list = [] lora_paths = [] - return_hidden_states = [] # NOTE: with openai API, the prompt's logprobs are always not computed @@ -1216,7 +1176,6 @@ def v1_chat_generate_request( image_data_list.append(image_data) audio_data_list.append(audio_data) modalities_list.append(modalities) - return_hidden_states.append(request.return_hidden_states) if len(all_requests) == 1: if is_multimodal: # processor will need text input @@ -1235,7 +1194,6 @@ def v1_chat_generate_request( modalities_list = modalities_list[0] lora_paths = lora_paths[0] request_ids = request_ids[0] - return_hidden_states = return_hidden_states[0] else: if tokenizer_manager.model_config.is_multimodal: # processor will need text input @@ -1262,7 +1220,6 @@ def v1_chat_generate_request( bootstrap_host=all_requests[0].bootstrap_host, bootstrap_port=all_requests[0].bootstrap_port, bootstrap_room=all_requests[0].bootstrap_room, - return_hidden_states=return_hidden_states, ) return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0] @@ -1323,21 +1280,6 @@ def v1_chat_generate_response( else: choice_logprobs = None - if isinstance(request, list) and request[idx].return_hidden_states: - include_hidden_states = True - elif not isinstance(request, list) and request.return_hidden_states: - include_hidden_states = True - else: - include_hidden_states = False - if include_hidden_states and ret_item["meta_info"].get("hidden_states", None): - hidden_states = ret_item["meta_info"]["hidden_states"] - hidden_states = hidden_states[1:] # trim off the prefill - hidden_states = ( - hidden_states[-1] if len(hidden_states) > 0 else [] - ) # slice out the last token - else: - hidden_states = None - finish_reason = ret_item["meta_info"]["finish_reason"] tool_calls = None @@ -1402,7 +1344,6 @@ def v1_chat_generate_response( "content": text if text else None, "tool_calls": tool_calls, "reasoning_content": reasoning_text if reasoning_text else None, - "hidden_states": hidden_states, }, "logprobs": choice_logprobs.model_dump() if choice_logprobs else None, "finish_reason": finish_reason["type"] if finish_reason else None, @@ -1428,7 +1369,6 @@ def v1_chat_generate_response( if finish_reason and "matched" in finish_reason else None ), - hidden_states=hidden_states, ) choices.append(choice_data) @@ -1497,7 +1437,6 @@ async def v1_chat_completions( if adapted_request.stream: parser_dict = {} reasoning_parser_dict = {} - hidden_states = None async def generate_stream_resp(): tool_call_first = True @@ -1507,16 +1446,12 @@ async def v1_chat_completions( prompt_tokens = {} completion_tokens = {} cached_tokens = {} - hidden_states = None try: async for content in tokenizer_manager.generate_request( adapted_request, raw_request ): index = content.get("index", 0) text = content["text"] - hidden_states = ( - content["meta_info"].get("hidden_states", None) or hidden_states - ) is_first = is_firsts.get(index, True) stream_buffer = stream_buffers.get(index, "") @@ -1638,7 +1573,6 @@ async def v1_chat_completions( if (delta and len(delta) == 0) or not delta: stream_buffers[index] = new_stream_buffer is_firsts[index] = is_first - n_prev_tokens[index] = n_prev_token continue if request.tool_choice != "none" and request.tools: @@ -1727,7 +1661,6 @@ async def v1_chat_completions( stream_buffers[index] = new_stream_buffer is_firsts[index] = is_first - n_prev_tokens[index] = n_prev_token else: # No tool calls => just treat this as normal text @@ -1760,7 +1693,6 @@ async def v1_chat_completions( yield f"data: {chunk.model_dump_json()}\n\n" stream_buffers[index] = new_stream_buffer is_firsts[index] = is_first - n_prev_tokens[index] = n_prev_token if finish_reason_type == "stop" and request.tool_choice != "none": parser = FunctionCallParser( tools=request.tools, @@ -1796,24 +1728,6 @@ async def v1_chat_completions( else: usage = None - if request.return_hidden_states and hidden_states: - hidden_states = hidden_states[1:] # trim off the prefill - hidden_states = ( - hidden_states[-1] if len(hidden_states) > 0 else [] - ) # slice out the last token - hidden_states_chunk = ChatCompletionStreamResponse( - id=content["meta_info"]["id"], - created=created, - choices=[ - ChatCompletionResponseStreamChoice( - index=index, - delta=DeltaMessage(hidden_states=hidden_states), - finish_reason=finish_reason_type, - ) - ], - model=request.model, - ) - yield f"data: {hidden_states_chunk.model_dump_json()}\n\n" final_usage_chunk = ChatCompletionStreamResponse( id=content["meta_info"]["id"], created=created, diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 3938ba25a..7c40a70dc 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -16,7 +16,7 @@ import time from typing import Dict, List, Optional, Union -from pydantic import BaseModel, Field, model_serializer, root_validator +from pydantic import BaseModel, Field, root_validator from typing_extensions import Literal @@ -182,7 +182,6 @@ class CompletionRequest(BaseModel): skip_special_tokens: bool = True lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None session_params: Optional[Dict] = None - return_hidden_states: Optional[bool] = False class CompletionResponseChoice(BaseModel): @@ -191,11 +190,6 @@ class CompletionResponseChoice(BaseModel): logprobs: Optional[LogProbs] = None finish_reason: Literal["stop", "length", "content_filter"] matched_stop: Union[None, int, str] = None - hidden_states: Optional[object] = None - - @model_serializer - def _serialize(self): - return exclude_if_none(self, ["hidden_states"]) class CompletionResponse(BaseModel): @@ -213,11 +207,6 @@ class CompletionResponseStreamChoice(BaseModel): logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None matched_stop: Union[None, int, str] = None - hidden_states: Optional[object] = None - - @model_serializer - def _serialize(self): - return exclude_if_none(self, ["hidden_states"]) class CompletionStreamResponse(BaseModel): @@ -411,9 +400,6 @@ class ChatCompletionRequest(BaseModel): bootstrap_port: Optional[int] = None bootstrap_room: Optional[int] = None - # Hidden States - return_hidden_states: Optional[bool] = False - class ChatMessage(BaseModel): role: Optional[str] = None @@ -430,11 +416,6 @@ class ChatCompletionResponseChoice(BaseModel): "stop", "length", "tool_calls", "content_filter", "function_call" ] matched_stop: Union[None, int, str] = None - hidden_states: Optional[object] = None - - @model_serializer - def _serialize(self): - return exclude_if_none(self, ["hidden_states"]) class ChatCompletionResponse(BaseModel): @@ -451,11 +432,6 @@ class DeltaMessage(BaseModel): content: Optional[str] = None reasoning_content: Optional[str] = None tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) - hidden_states: Optional[object] = None - - @model_serializer - def _serialize(self): - return exclude_if_none(self, ["hidden_states"]) class ChatCompletionResponseStreamChoice(BaseModel): @@ -508,8 +484,3 @@ class EmbeddingResponse(BaseModel): model: str object: str = "list" usage: Optional[UsageInfo] = None - - -def exclude_if_none(obj, field_names: List[str]): - omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names} - return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None} diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 8f827bb4a..ea295ad74 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -1,9 +1,7 @@ """ python3 -m unittest test_openai_server.TestOpenAIServer.test_batch python3 -m unittest test_openai_server.TestOpenAIServer.test_completion -python3 -m unittest test_openai_server.TestOpenAIServer.test_completion_stream -python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion -python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion_stream + """ import json @@ -11,7 +9,6 @@ import re import time import unittest -import numpy as np import openai from sglang.srt.hf_transformers_utils import get_tokenizer @@ -46,13 +43,7 @@ class TestOpenAIServer(CustomTestCase): kill_process_tree(cls.process.pid) def run_completion( - self, - echo, - logprobs, - use_list_input, - parallel_sample_num, - token_input, - return_hidden_states, + self, echo, logprobs, use_list_input, parallel_sample_num, token_input ): client = openai.Client(api_key=self.api_key, base_url=self.base_url) prompt = "The capital of France is" @@ -79,7 +70,6 @@ class TestOpenAIServer(CustomTestCase): echo=echo, logprobs=logprobs, n=parallel_sample_num, - extra_body=dict(return_hidden_states=return_hidden_states), ) assert len(response.choices) == num_choices * parallel_sample_num @@ -110,26 +100,8 @@ class TestOpenAIServer(CustomTestCase): assert response.usage.completion_tokens > 0 assert response.usage.total_tokens > 0 - if return_hidden_states: - hidden_states = response.choices[0].hidden_states - assert hidden_states is not None, "hidden_states was none" - hidden_states = np.asarray(hidden_states) - assert ( - len(hidden_states.shape) == 1 - ), f"hidden_states shape is not correct, was {hidden_states.shape}" - else: - assert not hasattr( - response.choices[0], "hidden_states" - ), "hidden_states was returned and should not have been" - def run_completion_stream( - self, - echo, - logprobs, - use_list_input, - parallel_sample_num, - token_input, - return_hidden_states, + self, echo, logprobs, use_list_input, parallel_sample_num, token_input ): client = openai.Client(api_key=self.api_key, base_url=self.base_url) prompt = "The capital of France is" @@ -158,44 +130,33 @@ class TestOpenAIServer(CustomTestCase): stream=True, stream_options={"include_usage": True}, n=parallel_sample_num, - extra_body=dict(return_hidden_states=return_hidden_states), ) is_firsts = {} - hidden_states = None for response in generator: usage = response.usage if usage is not None: - assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero" - assert usage.completion_tokens > 0, f"usage.completion_tokens was zero" - assert usage.total_tokens > 0, f"usage.total_tokens was zero" - continue - - if ( - hasattr(response.choices[0], "hidden_states") - and response.choices[0].hidden_states is not None - ): - hidden_states = response.choices[0].hidden_states + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.total_tokens > 0 continue index = response.choices[0].index is_first = is_firsts.get(index, True) if logprobs: - assert response.choices[0].logprobs, f"no logprobs in response" - assert isinstance( - response.choices[0].logprobs.tokens[0], str - ), f"{response.choices[0].logprobs.tokens[0]} is not a string" + assert response.choices[0].logprobs + assert isinstance(response.choices[0].logprobs.tokens[0], str) if not (is_first and echo): assert isinstance( response.choices[0].logprobs.top_logprobs[0], dict - ), f"top_logprobs was not a dictionary" + ) ret_num_top_logprobs = len( response.choices[0].logprobs.top_logprobs[0] ) # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}" - assert ret_num_top_logprobs > 0, f"ret_num_top_logprobs was 0" + assert ret_num_top_logprobs > 0 if is_first: if echo: @@ -203,29 +164,15 @@ class TestOpenAIServer(CustomTestCase): prompt ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}" is_firsts[index] = False - assert response.id, f"no id in response" - assert response.created, f"no created in response" + assert response.id + assert response.created for index in [i for i in range(parallel_sample_num * num_choices)]: assert not is_firsts.get( index, True ), f"index {index} is not found in the response" - if return_hidden_states: - assert hidden_states is not None, "hidden_states is not returned" - try: - hidden_states = np.asarray(hidden_states) - except Exception as e: - raise Exception(f"Failed to convert hidden states to numpy array: {e}") - assert ( - len(hidden_states.shape) == 1 - ), f"hidden_states shape is not correct, was {hidden_states.shape}" - else: - assert ( - hidden_states is None - ), "hidden_states was returned and should not have been" - - def run_chat_completion(self, logprobs, parallel_sample_num, return_hidden_states): + def run_chat_completion(self, logprobs, parallel_sample_num): client = openai.Client(api_key=self.api_key, base_url=self.base_url) response = client.chat.completions.create( model=self.model, @@ -240,7 +187,6 @@ class TestOpenAIServer(CustomTestCase): logprobs=logprobs is not None and logprobs > 0, top_logprobs=logprobs, n=parallel_sample_num, - extra_body=dict(return_hidden_states=return_hidden_states), ) if logprobs: @@ -264,21 +210,7 @@ class TestOpenAIServer(CustomTestCase): assert response.usage.completion_tokens > 0 assert response.usage.total_tokens > 0 - if return_hidden_states: - hidden_states = response.choices[0].hidden_states - assert hidden_states is not None, "hidden_states is not returned" - hidden_states = np.asarray(hidden_states) - assert ( - len(hidden_states.shape) == 1 - ), f"hidden_states shape is not correct, was {hidden_states.shape}" - else: - assert not hasattr( - response.choices[0], "hidden_states" - ), "hidden_states was returned and should not have been" - - def run_chat_completion_stream( - self, logprobs, parallel_sample_num=1, return_hidden_states=False - ): + def run_chat_completion_stream(self, logprobs, parallel_sample_num=1): client = openai.Client(api_key=self.api_key, base_url=self.base_url) generator = client.chat.completions.create( model=self.model, @@ -292,55 +224,40 @@ class TestOpenAIServer(CustomTestCase): stream=True, stream_options={"include_usage": True}, n=parallel_sample_num, - extra_body=dict(return_hidden_states=return_hidden_states), ) is_firsts = {} - hidden_states = None - top_logprob_tokens = [] for response in generator: usage = response.usage if usage is not None: - assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero" - assert usage.completion_tokens > 0, f"usage.completion_tokens was zero" - assert usage.total_tokens > 0, f"usage.total_tokens was zero" - continue - - if hasattr(response.choices[0].delta, "hidden_states"): - hidden_states = response.choices[0].delta.hidden_states + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.total_tokens > 0 continue index = response.choices[0].index data = response.choices[0].delta if is_firsts.get(index, True): - assert ( - data.role == "assistant" - ), f"data.role was not 'assistant' for first chunk" + assert data.role == "assistant" is_firsts[index] = False continue if logprobs: - assert response.choices[0].logprobs, f"logprobs was not returned" + assert response.choices[0].logprobs assert isinstance( response.choices[0].logprobs.content[0].top_logprobs[0].token, str - ), f"top_logprobs token was not a string" + ) assert isinstance( response.choices[0].logprobs.content[0].top_logprobs, list - ), f"top_logprobs was not a list" + ) ret_num_top_logprobs = len( response.choices[0].logprobs.content[0].top_logprobs ) assert ( ret_num_top_logprobs == logprobs ), f"{ret_num_top_logprobs} vs {logprobs}" - top_logprob_tokens.append( - response.choices[0].logprobs.content[0].top_logprobs[0].token - ) - assert ( - len(top_logprob_tokens) <= 2 or len(set(top_logprob_tokens)) > 1 - ), "Top Logprob tokens should not consistent of the same token repeated" assert ( isinstance(data.content, str) or isinstance(data.reasoning_content, str) @@ -355,20 +272,6 @@ class TestOpenAIServer(CustomTestCase): index, True ), f"index {index} is not found in the response" - if return_hidden_states: - assert hidden_states is not None, "hidden_states is not returned" - try: - hidden_states = np.asarray(hidden_states) - except Exception as e: - raise Exception(f"Failed to convert hidden states to numpy array: {e}") - assert ( - len(hidden_states.shape) == 1 - ), f"hidden_states shape is not correct, was {hidden_states.shape}" - else: - assert ( - hidden_states is None - ), "hidden_states was returned and should not have been" - def _create_batch(self, mode, client): if mode == "completion": input_file_path = "complete_input.jsonl" @@ -516,53 +419,43 @@ class TestOpenAIServer(CustomTestCase): assert del_response.deleted def test_completion(self): - for return_hidden_states in [False, True]: - for echo in [False, True]: - for logprobs in [None, 5]: - for use_list_input in [True, False]: - for parallel_sample_num in [1, 2]: - for token_input in [False, True]: - self.run_completion( - echo, - logprobs, - use_list_input, - parallel_sample_num, - token_input, - return_hidden_states, - ) + for echo in [False, True]: + for logprobs in [None, 5]: + for use_list_input in [True, False]: + for parallel_sample_num in [1, 2]: + for token_input in [False, True]: + self.run_completion( + echo, + logprobs, + use_list_input, + parallel_sample_num, + token_input, + ) def test_completion_stream(self): # parallel sampling and list input are not supported in streaming mode - for return_hidden_states in [False, True]: - for echo in [False, True]: - for logprobs in [None, 5]: - for use_list_input in [True, False]: - for parallel_sample_num in [1, 2]: - for token_input in [False, True]: - self.run_completion_stream( - echo, - logprobs, - use_list_input, - parallel_sample_num, - token_input, - return_hidden_states, - ) + for echo in [False, True]: + for logprobs in [None, 5]: + for use_list_input in [True, False]: + for parallel_sample_num in [1, 2]: + for token_input in [False, True]: + self.run_completion_stream( + echo, + logprobs, + use_list_input, + parallel_sample_num, + token_input, + ) def test_chat_completion(self): - for return_hidden_states in [False, True]: - for logprobs in [None, 5]: - for parallel_sample_num in [1, 2]: - self.run_chat_completion( - logprobs, parallel_sample_num, return_hidden_states - ) + for logprobs in [None, 5]: + for parallel_sample_num in [1, 2]: + self.run_chat_completion(logprobs, parallel_sample_num) def test_chat_completion_stream(self): - for return_hidden_states in [False, True]: - for logprobs in [None, 5]: - for parallel_sample_num in [1, 2]: - self.run_chat_completion_stream( - logprobs, parallel_sample_num, return_hidden_states - ) + for logprobs in [None, 5]: + for parallel_sample_num in [1, 2]: + self.run_chat_completion_stream(logprobs, parallel_sample_num) def test_batch(self): for mode in ["completion", "chat"]: