From 4f39bcf7ab4812ae2d7992e5718d46cf2d4e7ad1 Mon Sep 17 00:00:00 2001
From: kyle-pena-kuzco
Date: Mon, 19 May 2025 01:30:25 -0400
Subject: [PATCH] Implement `return_hidden_states` for the OpenAI API (#6137)

---
 python/sglang/srt/openai_api/adapter.py  |  88 +++++++++-
 python/sglang/srt/openai_api/protocol.py |  31 +++-
 test/srt/test_openai_server.py           | 209 +++++++++++++++++------
 3 files changed, 275 insertions(+), 53 deletions(-)

diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index 9f54641d9..14b521740 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -531,6 +531,7 @@ def v1_generate_request(
     logprob_start_lens = []
     top_logprobs_nums = []
     lora_paths = []
+    return_hidden_states = []
 
     for request in all_requests:
         # NOTE: with openai API, the prompt's logprobs are always not computed
@@ -577,6 +578,7 @@ def v1_generate_request(
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
+        return_hidden_states.append(request.return_hidden_states)
 
     if len(all_requests) == 1:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
@@ -588,6 +590,7 @@ def v1_generate_request(
         logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
         lora_paths = lora_paths[0]
+        return_hidden_states = return_hidden_states[0]
     else:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
@@ -604,6 +607,7 @@ def v1_generate_request(
         stream=all_requests[0].stream,
         rid=request_ids,
         lora_path=lora_paths,
+        return_hidden_states=return_hidden_states,
     )
 
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
@@ -669,6 +673,17 @@ def v1_generate_response(
         else:
             logprobs = None
 
+        hidden_states = None
+        if isinstance(request, list) and request[idx].return_hidden_states:
+            hidden_states = ret_item["meta_info"].get("hidden_states", None)
+        elif (not isinstance(request, list)) and request.return_hidden_states:
+            hidden_states = ret_item["meta_info"].get("hidden_states", None)
+        if hidden_states is not None:
+            hidden_states = hidden_states[1:]  # trim off the prefill
+            hidden_states = (
+                hidden_states[-1] if len(hidden_states) > 0 else []
+            )  # slice out the last token
+
         finish_reason = ret_item["meta_info"]["finish_reason"]
 
         if to_file:
@@ -695,6 +710,7 @@ def v1_generate_response(
                     if finish_reason and "matched" in finish_reason
                     else None
                 ),
+                hidden_states=hidden_states,
             )
 
         choices.append(choice_data)
@@ -719,6 +735,7 @@ def v1_generate_response(
                     + ret[i]["meta_info"]["completion_tokens"],
                 },
                 "system_fingerprint": None,
+                "hidden_states": hidden_states,
            },
        }
        responses.append(response)
@@ -763,6 +780,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
             prompt_tokens = {}
             completion_tokens = {}
             cached_tokens = {}
+            hidden_states = None
 
             try:
                 async for content in tokenizer_manager.generate_request(
@@ -777,6 +795,9 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                     completion_tokens[index] = content["meta_info"]["completion_tokens"]
                     cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
+                    hidden_states = (
+                        content["meta_info"].get("hidden_states", None) or hidden_states
+                    )
 
                     if not stream_buffer:  # The first chunk
                         if request.echo:
@@ -882,7 +903,25 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     total_tokens=total_prompt_tokens + total_completion_tokens,
                    prompt_tokens_details=prompt_tokens_details,
                )
-
+                if request.return_hidden_states and hidden_states:
+                    hidden_states = hidden_states[1:]  # trim off the prefill
+                    hidden_states = (
+                        hidden_states[-1] if len(hidden_states) > 0 else []
+                    )  # slice out the last token
+                    hidden_states_chunk = CompletionStreamResponse(
+                        id=content["meta_info"]["id"],
+                        created=created,
+                        choices=[
+                            CompletionResponseStreamChoice(
+                                text="",
+                                index=index,
+                                hidden_states=hidden_states,
+                                finish_reason=None,
+                            )
+                        ],
+                        model=request.model,
+                    )
+                    yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
                final_usage_chunk = CompletionStreamResponse(
                    id=content["meta_info"]["id"],
                    created=created,
@@ -959,6 +998,7 @@ def v1_chat_generate_request(
     top_logprobs_nums = []
     modalities_list = []
     lora_paths = []
+    return_hidden_states = []
 
     # NOTE: with openai API, the prompt's logprobs are always not computed
 
@@ -1176,6 +1216,7 @@ def v1_chat_generate_request(
         image_data_list.append(image_data)
         audio_data_list.append(audio_data)
         modalities_list.append(modalities)
+        return_hidden_states.append(request.return_hidden_states)
     if len(all_requests) == 1:
         if is_multimodal:
             # processor will need text input
@@ -1194,6 +1235,7 @@ def v1_chat_generate_request(
         modalities_list = modalities_list[0]
         lora_paths = lora_paths[0]
         request_ids = request_ids[0]
+        return_hidden_states = return_hidden_states[0]
     else:
         if tokenizer_manager.model_config.is_multimodal:
             # processor will need text input
@@ -1220,6 +1262,7 @@ def v1_chat_generate_request(
         bootstrap_host=all_requests[0].bootstrap_host,
         bootstrap_port=all_requests[0].bootstrap_port,
         bootstrap_room=all_requests[0].bootstrap_room,
+        return_hidden_states=return_hidden_states,
     )
 
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
@@ -1280,6 +1323,21 @@ def v1_chat_generate_response(
         else:
             choice_logprobs = None
 
+        if isinstance(request, list) and request[idx].return_hidden_states:
+            include_hidden_states = True
+        elif not isinstance(request, list) and request.return_hidden_states:
+            include_hidden_states = True
+        else:
+            include_hidden_states = False
+        if include_hidden_states and ret_item["meta_info"].get("hidden_states", None):
+            hidden_states = ret_item["meta_info"]["hidden_states"]
+            hidden_states = hidden_states[1:]  # trim off the prefill
+            hidden_states = (
+                hidden_states[-1] if len(hidden_states) > 0 else []
+            )  # slice out the last token
+        else:
+            hidden_states = None
+
         finish_reason = ret_item["meta_info"]["finish_reason"]
 
         tool_calls = None
@@ -1344,6 +1402,7 @@ def v1_chat_generate_response(
                    "content": text if text else None,
                    "tool_calls": tool_calls,
                    "reasoning_content": reasoning_text if reasoning_text else None,
+                    "hidden_states": hidden_states,
                },
                "logprobs": choice_logprobs.model_dump() if choice_logprobs else None,
                "finish_reason": finish_reason["type"] if finish_reason else None,
@@ -1369,6 +1428,7 @@ def v1_chat_generate_response(
                    if finish_reason and "matched" in finish_reason
                    else None
                ),
+                hidden_states=hidden_states,
            )
 
         choices.append(choice_data)
@@ -1437,6 +1497,7 @@ async def v1_chat_completions(
     if adapted_request.stream:
         parser_dict = {}
         reasoning_parser_dict = {}
+        hidden_states = None
 
         async def generate_stream_resp():
             tool_call_first = True
@@ -1446,12 +1507,16 @@ async def v1_chat_completions(
             prompt_tokens = {}
             completion_tokens = {}
             cached_tokens = {}
+            hidden_states = None
 
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
                     index = content.get("index", 0)
                     text = content["text"]
+                    hidden_states = (
+                        content["meta_info"].get("hidden_states", None) or hidden_states
+                    )
                     is_first = is_firsts.get(index, True)
                     stream_buffer = stream_buffers.get(index, "")
@@ -1573,6 +1638,7 @@ async def v1_chat_completions(
                     if (delta and len(delta) == 0) or not delta:
                         stream_buffers[index] = new_stream_buffer
                         is_firsts[index] = is_first
+                        n_prev_tokens[index] = n_prev_token
                         continue
 
                     if request.tool_choice != "none" and request.tools:
@@ -1661,6 +1727,7 @@ async def v1_chat_completions(
 
                             stream_buffers[index] = new_stream_buffer
                             is_firsts[index] = is_first
+                            n_prev_tokens[index] = n_prev_token
 
                     else:
                         # No tool calls => just treat this as normal text
@@ -1693,6 +1760,7 @@ async def v1_chat_completions(
                         yield f"data: {chunk.model_dump_json()}\n\n"
                         stream_buffers[index] = new_stream_buffer
                         is_firsts[index] = is_first
+                        n_prev_tokens[index] = n_prev_token
                 if finish_reason_type == "stop" and request.tool_choice != "none":
                     parser = FunctionCallParser(
                         tools=request.tools,
@@ -1728,6 +1796,24 @@ async def v1_chat_completions(
                 else:
                     usage = None
 
+                if request.return_hidden_states and hidden_states:
+                    hidden_states = hidden_states[1:]  # trim off the prefill
+                    hidden_states = (
+                        hidden_states[-1] if len(hidden_states) > 0 else []
+                    )  # slice out the last token
+                    hidden_states_chunk = ChatCompletionStreamResponse(
+                        id=content["meta_info"]["id"],
+                        created=created,
+                        choices=[
+                            ChatCompletionResponseStreamChoice(
+                                index=index,
+                                delta=DeltaMessage(hidden_states=hidden_states),
+                                finish_reason=finish_reason_type,
+                            )
+                        ],
+                        model=request.model,
+                    )
+                    yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
                final_usage_chunk = ChatCompletionStreamResponse(
                    id=content["meta_info"]["id"],
                    created=created,
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index 7c40a70dc..3938ba25a 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -16,7 +16,7 @@
 import time
 from typing import Dict, List, Optional, Union
 
-from pydantic import BaseModel, Field, root_validator
+from pydantic import BaseModel, Field, model_serializer, root_validator
 from typing_extensions import Literal
 
 
@@ -182,6 +182,7 @@ class CompletionRequest(BaseModel):
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
     session_params: Optional[Dict] = None
+    return_hidden_states: Optional[bool] = False
 
 
 class CompletionResponseChoice(BaseModel):
@@ -190,6 +191,11 @@ class CompletionResponseChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Literal["stop", "length", "content_filter"]
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])
 
 
 class CompletionResponse(BaseModel):
@@ -207,6 +213,11 @@ class CompletionResponseStreamChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])
 
 
 class CompletionStreamResponse(BaseModel):
@@ -400,6 +411,9 @@ class ChatCompletionRequest(BaseModel):
     bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None
 
+    # Hidden States
+    return_hidden_states: Optional[bool] = False
+
 
 class ChatMessage(BaseModel):
     role: Optional[str] = None
@@ -416,6 +430,11 @@ class ChatCompletionResponseChoice(BaseModel):
         "stop", "length", "tool_calls", "content_filter", "function_call"
     ]
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])
 
 
 class ChatCompletionResponse(BaseModel):
@@ -432,6 +451,11 @@ class DeltaMessage(BaseModel):
     content: Optional[str] = None
     reasoning_content: Optional[str] = None
     tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
@@ -484,3 +508,8 @@ class EmbeddingResponse(BaseModel):
     model: str
     object: str = "list"
     usage: Optional[UsageInfo] = None
+
+
+def exclude_if_none(obj, field_names: List[str]):
+    omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
+    return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
index ea295ad74..8f827bb4a 100644
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -1,7 +1,9 @@
 """
 python3 -m unittest test_openai_server.TestOpenAIServer.test_batch
 python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
-
+python3 -m unittest test_openai_server.TestOpenAIServer.test_completion_stream
+python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion
+python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion_stream
 """
 
 import json
@@ -9,6 +11,7 @@
 import re
 import time
 import unittest
 
+import numpy as np
 import openai
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -43,7 +46,13 @@ class TestOpenAIServer(CustomTestCase):
         kill_process_tree(cls.process.pid)
 
     def run_completion(
-        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
+        self,
+        echo,
+        logprobs,
+        use_list_input,
+        parallel_sample_num,
+        token_input,
+        return_hidden_states,
     ):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"
@@ -70,6 +79,7 @@ class TestOpenAIServer(CustomTestCase):
             echo=echo,
             logprobs=logprobs,
             n=parallel_sample_num,
+            extra_body=dict(return_hidden_states=return_hidden_states),
         )
 
         assert len(response.choices) == num_choices * parallel_sample_num
@@ -100,8 +110,26 @@ class TestOpenAIServer(CustomTestCase):
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0
 
+        if return_hidden_states:
+            hidden_states = response.choices[0].hidden_states
+            assert hidden_states is not None, "hidden_states was none"
+            hidden_states = np.asarray(hidden_states)
+            assert (
+                len(hidden_states.shape) == 1
+            ), f"hidden_states shape is not correct, was {hidden_states.shape}"
+        else:
+            assert not hasattr(
+                response.choices[0], "hidden_states"
+            ), "hidden_states was returned and should not have been"
+
     def run_completion_stream(
-        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
+        self,
+        echo,
+        logprobs,
+        use_list_input,
+        parallel_sample_num,
+        token_input,
+        return_hidden_states,
     ):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"
@@ -130,33 +158,44 @@ class TestOpenAIServer(CustomTestCase):
             stream=True,
             stream_options={"include_usage": True},
             n=parallel_sample_num,
+            extra_body=dict(return_hidden_states=return_hidden_states),
         )
 
         is_firsts = {}
+        hidden_states = None
         for response in generator:
             usage = response.usage
             if usage is not None:
-                assert usage.prompt_tokens > 0
-                assert usage.completion_tokens > 0
-                assert usage.total_tokens > 0
+                assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
+                assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
+                assert usage.total_tokens > 0, f"usage.total_tokens was zero"
+                continue
+
+            if (
+                hasattr(response.choices[0], "hidden_states")
+                and response.choices[0].hidden_states is not None
+            ):
+                hidden_states = response.choices[0].hidden_states
                 continue
 
             index = response.choices[0].index
             is_first = is_firsts.get(index, True)
 
             if logprobs:
-                assert response.choices[0].logprobs
-                assert isinstance(response.choices[0].logprobs.tokens[0], str)
+                assert response.choices[0].logprobs, f"no logprobs in response"
+                assert isinstance(
+                    response.choices[0].logprobs.tokens[0], str
+                ), f"{response.choices[0].logprobs.tokens[0]} is not a string"
                 if not (is_first and echo):
                     assert isinstance(
                         response.choices[0].logprobs.top_logprobs[0], dict
-                    )
+                    ), f"top_logprobs was not a dictionary"
                     ret_num_top_logprobs = len(
                         response.choices[0].logprobs.top_logprobs[0]
                     )
                     # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
                     # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
-                    assert ret_num_top_logprobs > 0
+                    assert ret_num_top_logprobs > 0, f"ret_num_top_logprobs was 0"
 
             if is_first:
                 if echo:
@@ -164,15 +203,29 @@ class TestOpenAIServer(CustomTestCase):
                         prompt
                     ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
                 is_firsts[index] = False
-            assert response.id
-            assert response.created
+            assert response.id, f"no id in response"
+            assert response.created, f"no created in response"
 
         for index in [i for i in range(parallel_sample_num * num_choices)]:
             assert not is_firsts.get(
                 index, True
             ), f"index {index} is not found in the response"
 
+        if return_hidden_states:
+            assert hidden_states is not None, "hidden_states is not returned"
+            try:
+                hidden_states = np.asarray(hidden_states)
+            except Exception as e:
+                raise Exception(f"Failed to convert hidden states to numpy array: {e}")
+            assert (
+                len(hidden_states.shape) == 1
+            ), f"hidden_states shape is not correct, was {hidden_states.shape}"
+        else:
+            assert (
+                hidden_states is None
+            ), "hidden_states was returned and should not have been"
+
-    def run_chat_completion(self, logprobs, parallel_sample_num):
+    def run_chat_completion(self, logprobs, parallel_sample_num, return_hidden_states):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         response = client.chat.completions.create(
             model=self.model,
             messages=[
@@ -187,6 +240,7 @@ class TestOpenAIServer(CustomTestCase):
             logprobs=logprobs is not None and logprobs > 0,
             top_logprobs=logprobs,
             n=parallel_sample_num,
+            extra_body=dict(return_hidden_states=return_hidden_states),
         )
 
         if logprobs:
@@ -210,7 +264,21 @@ class TestOpenAIServer(CustomTestCase):
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0
 
-    def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
+        if return_hidden_states:
+            hidden_states = response.choices[0].hidden_states
+            assert hidden_states is not None, "hidden_states is not returned"
+            hidden_states = np.asarray(hidden_states)
+            assert (
+                len(hidden_states.shape) == 1
+            ), f"hidden_states shape is not correct, was {hidden_states.shape}"
+        else:
+            assert not hasattr(
+                response.choices[0], "hidden_states"
+            ), "hidden_states was returned and should not have been"
+
+    def run_chat_completion_stream(
+        self, logprobs, parallel_sample_num=1, return_hidden_states=False
+    ):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         generator = client.chat.completions.create(
             model=self.model,
             messages=[
@@ -224,40 +292,55 @@ class TestOpenAIServer(CustomTestCase):
             stream=True,
             stream_options={"include_usage": True},
             n=parallel_sample_num,
+            extra_body=dict(return_hidden_states=return_hidden_states),
         )
 
         is_firsts = {}
+        hidden_states = None
+        top_logprob_tokens = []
         for response in generator:
             usage = response.usage
             if usage is not None:
-                assert usage.prompt_tokens > 0
-                assert usage.completion_tokens > 0
-                assert usage.total_tokens > 0
+                assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
+                assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
+                assert usage.total_tokens > 0, f"usage.total_tokens was zero"
+                continue
+
+            if hasattr(response.choices[0].delta, "hidden_states"):
+                hidden_states = response.choices[0].delta.hidden_states
                 continue
 
             index = response.choices[0].index
             data = response.choices[0].delta
 
             if is_firsts.get(index, True):
-                assert data.role == "assistant"
+                assert (
+                    data.role == "assistant"
+                ), f"data.role was not 'assistant' for first chunk"
                 is_firsts[index] = False
                 continue
 
             if logprobs:
-                assert response.choices[0].logprobs
+                assert response.choices[0].logprobs, f"logprobs was not returned"
                 assert isinstance(
                     response.choices[0].logprobs.content[0].top_logprobs[0].token, str
-                )
+                ), f"top_logprobs token was not a string"
                 assert isinstance(
                     response.choices[0].logprobs.content[0].top_logprobs, list
-                )
+                ), f"top_logprobs was not a list"
                 ret_num_top_logprobs = len(
                     response.choices[0].logprobs.content[0].top_logprobs
                 )
                 assert (
                     ret_num_top_logprobs == logprobs
                 ), f"{ret_num_top_logprobs} vs {logprobs}"
+                top_logprob_tokens.append(
+                    response.choices[0].logprobs.content[0].top_logprobs[0].token
+                )
+                assert (
+                    len(top_logprob_tokens) <= 2 or len(set(top_logprob_tokens)) > 1
+                ), "Top Logprob tokens should not consist of the same token repeated"
 
             assert (
                 isinstance(data.content, str)
                 or isinstance(data.reasoning_content, str)
@@ -272,6 +355,20 @@ class TestOpenAIServer(CustomTestCase):
             assert not is_firsts.get(
                 index, True
             ), f"index {index} is not found in the response"
 
+        if return_hidden_states:
+            assert hidden_states is not None, "hidden_states is not returned"
+            try:
+                hidden_states = np.asarray(hidden_states)
+            except Exception as e:
+                raise Exception(f"Failed to convert hidden states to numpy array: {e}")
+            assert (
+                len(hidden_states.shape) == 1
+            ), f"hidden_states shape is not correct, was {hidden_states.shape}"
+        else:
+            assert (
+                hidden_states is None
+            ), "hidden_states was returned and should not have been"
+
     def _create_batch(self, mode, client):
         if mode == "completion":
             input_file_path = "complete_input.jsonl"
@@ -419,43 +516,53 @@ class TestOpenAIServer(CustomTestCase):
         assert del_response.deleted
 
     def test_completion(self):
-        for echo in [False, True]:
-            for logprobs in [None, 5]:
-                for use_list_input in [True, False]:
-                    for parallel_sample_num in [1, 2]:
-                        for token_input in [False, True]:
-                            self.run_completion(
-                                echo,
-                                logprobs,
-                                use_list_input,
-                                parallel_sample_num,
-                                token_input,
-                            )
+        for return_hidden_states in [False, True]:
+            for echo in [False, True]:
+                for logprobs in [None, 5]:
+                    for use_list_input in [True, False]:
+                        for parallel_sample_num in [1, 2]:
+                            for token_input in [False, True]:
+                                self.run_completion(
+                                    echo,
+                                    logprobs,
+                                    use_list_input,
+                                    parallel_sample_num,
+                                    token_input,
+                                    return_hidden_states,
+                                )
 
     def test_completion_stream(self):
         # parallel sampling and list input are not supported in streaming mode
-        for echo in [False, True]:
-            for logprobs in [None, 5]:
-                for use_list_input in [True, False]:
-                    for parallel_sample_num in [1, 2]:
-                        for token_input in [False, True]:
-                            self.run_completion_stream(
-                                echo,
-                                logprobs,
-                                use_list_input,
-                                parallel_sample_num,
-                                token_input,
-                            )
+        for return_hidden_states in [False, True]:
+            for echo in [False, True]:
+                for logprobs in [None, 5]:
+                    for use_list_input in [True, False]:
+                        for parallel_sample_num in [1, 2]:
+                            for token_input in [False, True]:
+                                self.run_completion_stream(
+                                    echo,
+                                    logprobs,
+                                    use_list_input,
+                                    parallel_sample_num,
+                                    token_input,
+                                    return_hidden_states,
+                                )
 
     def test_chat_completion(self):
-        for logprobs in [None, 5]:
-            for parallel_sample_num in [1, 2]:
-                self.run_chat_completion(logprobs, parallel_sample_num)
+        for return_hidden_states in [False, True]:
+            for logprobs in [None, 5]:
+                for parallel_sample_num in [1, 2]:
+                    self.run_chat_completion(
+                        logprobs, parallel_sample_num, return_hidden_states
+                    )
 
     def test_chat_completion_stream(self):
-        for logprobs in [None, 5]:
-            for parallel_sample_num in [1, 2]:
-                self.run_chat_completion_stream(logprobs, parallel_sample_num)
+        for return_hidden_states in [False, True]:
+            for logprobs in [None, 5]:
+                for parallel_sample_num in [1, 2]:
+                    self.run_chat_completion_stream(
+                        logprobs, parallel_sample_num, return_hidden_states
+                    )
 
     def test_batch(self):
         for mode in ["completion", "chat"]:
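
Usage note (editorial addition, not part of the patch): with this change deployed, any OpenAI-compatible client can opt into hidden states through the request's extra body, exactly as the updated tests do. The sketch below is illustrative only; the base URL, API key, and model name are placeholders for a running SGLang server, and `hidden_states` is expected to hold the hidden state of the last generated token as a plain list of floats.

import openai

# Placeholder server details; assumes an SGLang OpenAI-compatible server is already running.
client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

# Non-streaming completion: the hidden states ride along on the choice object.
completion = client.completions.create(
    model="default",  # placeholder model name
    prompt="The capital of France is",
    max_tokens=8,
    extra_body=dict(return_hidden_states=True),
)
print(len(completion.choices[0].hidden_states))

# Chat completions accept the same flag and attach hidden_states to the choice as well.
chat = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    extra_body=dict(return_hidden_states=True),
)
print(chat.choices[0].hidden_states is not None)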
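
For streaming requests, the patch sends the hidden states as one extra SSE chunk shortly before the final usage chunk: on `delta.hidden_states` for chat completions and on the choice itself for completions. A minimal consumer sketch, under the same placeholder assumptions as above:

import openai

client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

stream = client.chat.completions.create(
    model="default",  # placeholder model name
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    stream=True,
    stream_options={"include_usage": True},
    extra_body=dict(return_hidden_states=True),
)

hidden_states = None
text_parts = []
for chunk in stream:
    if chunk.usage is not None:
        # Final usage-only chunk; it carries no choices.
        continue
    delta = chunk.choices[0].delta
    if getattr(delta, "hidden_states", None) is not None:
        # Dedicated hidden-states chunk added by this patch; it carries no text.
        hidden_states = delta.hidden_states
        continue
    if delta.content:
        text_parts.append(delta.content)

print("".join(text_parts))
print(hidden_states is not None)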