diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index 7ebeac5ce..171ff13cd 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -135,6 +135,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `download_dir` | Overrides the default Hugging Face cache directory for model weights. | None |
 | `base_gpu_id` | Sets the first GPU to use when distributing the model across multiple GPUs. | `0` |
 | `allow_auto_truncate`| Automatically truncate requests that exceed the maximum input length. | `False` |
+| `enable_return_hidden_states` | Enables returning hidden states to the user. | `False` |
 
 ## Logging
 
diff --git a/examples/runtime/hidden_states/hidden_states_engine.py b/examples/runtime/hidden_states/hidden_states_engine.py
index 8af883ab1..60ab302ca 100644
--- a/examples/runtime/hidden_states/hidden_states_engine.py
+++ b/examples/runtime/hidden_states/hidden_states_engine.py
@@ -22,6 +22,7 @@ def main():
     # Create an LLM.
     llm = sgl.Engine(
         model_path="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
+        enable_return_hidden_states=True,
     )
 
     sampling_params = {
diff --git a/examples/runtime/hidden_states/hidden_states_server.py b/examples/runtime/hidden_states/hidden_states_server.py
index 96045fad9..b04f74372 100644
--- a/examples/runtime/hidden_states/hidden_states_server.py
+++ b/examples/runtime/hidden_states/hidden_states_server.py
@@ -23,7 +23,7 @@ else:
 def main():
     # Launch the server
     server_process, port = launch_server_cmd(
-        "python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-1.5B-instruct --host 0.0.0.0"
+        "python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-1.5B-instruct --enable-return-hidden-states --host 0.0.0.0"
     )
 
     wait_for_server(f"http://localhost:{port}")
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 40c220c4e..7af81814f 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -99,7 +99,7 @@ class GenerateReqInput:
     custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
 
     # Whether to return hidden states
-    return_hidden_states: bool = False
+    return_hidden_states: Union[List[bool], bool] = False
 
     # For disaggregated inference
     bootstrap_host: Optional[Union[List[str], str]] = None
@@ -409,7 +409,11 @@ class GenerateReqInput:
                 if self.custom_logit_processor is not None
                 else None
             ),
-            return_hidden_states=self.return_hidden_states,
+            return_hidden_states=(
+                self.return_hidden_states[i]
+                if isinstance(self.return_hidden_states, list)
+                else self.return_hidden_states
+            ),
             # if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
             bootstrap_host=(
                 self.bootstrap_host[i] if self.bootstrap_host is not None else None
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index d71bbdf07..e6c3189cb 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -418,6 +418,20 @@ class TokenizerManager:
 
         obj.normalize_batch_and_arguments()
 
+        if isinstance(obj, GenerateReqInput):
+            return_hidden_states = obj.return_hidden_states
+            has_return_hidden_states = return_hidden_states is True or (
+                isinstance(return_hidden_states, list) and any(return_hidden_states)
+            )
+            if (
+                not self.server_args.enable_return_hidden_states
+                and has_return_hidden_states
+            ):
+                raise ValueError(
+                    "return_hidden_states=True requires the server to be started "
+                    "with --enable-return-hidden-states (ServerArgs.enable_return_hidden_states)."
+                )
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index be8a1ad14..51a083eb6 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -235,6 +235,10 @@ class CudaGraphRunner:
                 self.model_runner.server_args.speculative_num_draft_tokens
             )
 
+        # If returning hidden states is enabled, set the initial capture hidden mode
+        # to FULL to avoid a double capture on startup.
+        if model_runner.server_args.enable_return_hidden_states:
+            self.capture_hidden_mode = CaptureHiddenMode.FULL
+
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
@@ -342,11 +346,29 @@ class CudaGraphRunner:
             else True
         )
 
+        requested_capture_hidden_mode = max(
+            forward_batch.capture_hidden_mode,
+            (
+                forward_batch.spec_info.capture_hidden_mode
+                if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
+                is not None
+                else CaptureHiddenMode.NULL
+            ),
+        )
+        capture_hidden_mode_matches = (
+            requested_capture_hidden_mode == CaptureHiddenMode.NULL
+            or requested_capture_hidden_mode == self.capture_hidden_mode
+        )
         is_tbo_supported = (
             forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
        )
 
-        return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
+        return (
+            is_bs_supported
+            and is_encoder_lens_supported
+            and is_tbo_supported
+            and capture_hidden_mode_matches
+        )
 
     def capture(self) -> None:
         profile_context = empty_context()
@@ -541,21 +563,34 @@ class CudaGraphRunner:
         return graph, out
 
     def recapture_if_needed(self, forward_batch: ForwardBatch):
-        # If the capture_hidden_mode changes, we need to recapture the graph
-        hidden_mode_from_spec_info = getattr(
+
+        # If the required capture_hidden_mode changes, we need to recapture the graph.
+
+        # These are the different factors that can influence the capture_hidden_mode.
+        capture_hidden_mode_required_by_forward_batch = (
+            forward_batch.capture_hidden_mode
+        )
+        capture_hidden_mode_required_by_spec_info = getattr(
             forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
         )
-        if (
-            forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
-            and self.capture_hidden_mode != CaptureHiddenMode.FULL
-        ):
-            self.capture_hidden_mode = CaptureHiddenMode.FULL
-            self.capture()
-        elif (
-            forward_batch.capture_hidden_mode != CaptureHiddenMode.FULL
-            and self.capture_hidden_mode != hidden_mode_from_spec_info
-        ):
-            self.capture_hidden_mode = hidden_mode_from_spec_info
+        capture_hidden_mode_required_for_returning_hidden_states = (
+            CaptureHiddenMode.FULL
+            if self.model_runner.server_args.enable_return_hidden_states
+            else CaptureHiddenMode.NULL
+        )
+
+        # Determine the highest capture_hidden_mode required.
+        # (If we have FULL, we can emulate LAST or NULL.)
+        # (If we have LAST, we can emulate NULL.)
+        required_capture_hidden_mode = max(
+            capture_hidden_mode_required_by_forward_batch,
+            capture_hidden_mode_required_by_spec_info,
+            capture_hidden_mode_required_for_returning_hidden_states,
+        )
+
+        # If the current hidden mode no longer matches the required one,
+        # set it to what is required and re-capture.
+        if self.capture_hidden_mode != required_capture_hidden_mode:
+            self.capture_hidden_mode = required_capture_hidden_mode
             self.capture()
 
     def replay_prepare(
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 2d0328e29..1205ebee6 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -31,6 +31,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 from enum import IntEnum, auto
+from functools import total_ordering
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -117,13 +118,14 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.DECODE or self == ForwardMode.IDLE
 
 
+@total_ordering
 class CaptureHiddenMode(IntEnum):
     # Do not capture anything.
-    NULL = auto()
-    # Capture hidden states of all tokens.
-    FULL = auto()
+    NULL = 0
     # Capture a hidden state of the last token.
-    LAST = auto()
+    LAST = 1
+    # Capture hidden states of all tokens.
+    FULL = 2
 
     def need_capture(self):
         return self != CaptureHiddenMode.NULL
@@ -134,6 +136,9 @@ class CaptureHiddenMode(IntEnum):
     def is_last(self):
         return self == CaptureHiddenMode.LAST
 
+    def __lt__(self, other):
+        return self.value < other.value
+
 
 @dataclass
 class ForwardBatch:
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index 31af46bb9..aad9de93e 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -542,6 +542,7 @@ def v1_generate_request(
     logprob_start_lens = []
     top_logprobs_nums = []
     lora_paths = []
+    return_hidden_states = []
 
     for request in all_requests:
         # NOTE: with openai API, the prompt's logprobs are always not computed
@@ -588,6 +589,7 @@ def v1_generate_request(
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
+        return_hidden_states.append(request.return_hidden_states)
 
     if len(all_requests) == 1:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
@@ -599,6 +601,7 @@ def v1_generate_request(
         logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
         lora_paths = lora_paths[0]
+        return_hidden_states = return_hidden_states[0]
     else:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
@@ -615,6 +618,7 @@ def v1_generate_request(
         stream=all_requests[0].stream,
         rid=request_ids,
         lora_path=lora_paths,
+        return_hidden_states=return_hidden_states,
         bootstrap_host=all_requests[0].bootstrap_host,
         bootstrap_port=all_requests[0].bootstrap_port,
         bootstrap_room=all_requests[0].bootstrap_room,
@@ -683,6 +687,16 @@ def v1_generate_response(
         else:
             logprobs = None
 
+        hidden_states = None
+        if isinstance(request, list) and request[idx].return_hidden_states:
+            hidden_states = ret_item["meta_info"].get("hidden_states", None)
+        elif (not isinstance(request, list)) and request.return_hidden_states:
+            hidden_states = ret_item["meta_info"].get("hidden_states", None)
+        if hidden_states is not None:
+            hidden_states = (
+                hidden_states[-1] if hidden_states and len(hidden_states) > 1 else []
+            )
+
         finish_reason = ret_item["meta_info"]["finish_reason"]
 
         if to_file:
@@ -698,6 +712,8 @@ def v1_generate_response(
                     else None
                 ),
             }
+            if hidden_states is not None:
+                choice_data["hidden_states"] = hidden_states
         else:
             choice_data = CompletionResponseChoice(
                 index=idx,
@@ -709,6 +725,7 @@ def v1_generate_response(
                     if finish_reason and "matched" in finish_reason
                     else None
                 ),
+                hidden_states=hidden_states,
             )
 
         choices.append(choice_data)
@@ -777,6 +794,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
             prompt_tokens = {}
             completion_tokens = {}
             cached_tokens = {}
+            hidden_states = {}
 
             try:
                 async for content in tokenizer_manager.generate_request(
@@ -791,6 +809,9 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                     completion_tokens[index] = content["meta_info"]["completion_tokens"]
                     cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
+                    hidden_states[index] = content["meta_info"].get(
+                        "hidden_states", None
+                    ) or hidden_states.get(index)
 
                     if not stream_buffer:  # The first chunk
                         if request.echo:
@@ -873,6 +894,27 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     n_prev_tokens[index] = n_prev_token
 
                     yield f"data: {chunk.model_dump_json()}\n\n"
+                if request.return_hidden_states and hidden_states:
+                    for index, choice_hidden_states in hidden_states.items():
+                        last_token_hidden_states = (
+                            choice_hidden_states[-1]
+                            if choice_hidden_states and len(choice_hidden_states) > 1
+                            else []
+                        )
+                        hidden_states_chunk = CompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            created=created,
+                            choices=[
+                                CompletionResponseStreamChoice(
+                                    text="",
+                                    index=index,
+                                    hidden_states=last_token_hidden_states,
+                                    finish_reason=None,
+                                )
+                            ],
+                            model=request.model,
+                        )
+                        yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
 
                 if request.stream_options and request.stream_options.include_usage:
                     total_prompt_tokens = sum(
                         tokens
@@ -973,6 +1015,7 @@ def v1_chat_generate_request(
     top_logprobs_nums = []
    modalities_list = []
     lora_paths = []
+    return_hidden_states = []
 
     # NOTE: with openai API, the prompt's logprobs are always not computed
 
@@ -1215,6 +1258,7 @@ def v1_chat_generate_request(
         image_data_list.append(image_data)
         audio_data_list.append(audio_data)
         modalities_list.append(modalities)
+        return_hidden_states.append(request.return_hidden_states)
 
     if len(all_requests) == 1:
         if is_multimodal:
             # processor will need text input
@@ -1233,6 +1277,7 @@ def v1_chat_generate_request(
         modalities_list = modalities_list[0]
         lora_paths = lora_paths[0]
         request_ids = request_ids[0]
+        return_hidden_states = return_hidden_states[0]
     else:
         if tokenizer_manager.model_config.is_multimodal:
             # processor will need text input
@@ -1259,6 +1304,7 @@ def v1_chat_generate_request(
         bootstrap_host=all_requests[0].bootstrap_host,
         bootstrap_port=all_requests[0].bootstrap_port,
         bootstrap_room=all_requests[0].bootstrap_room,
+        return_hidden_states=return_hidden_states,
     )
 
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
@@ -1319,6 +1365,20 @@ def v1_chat_generate_response(
         else:
             choice_logprobs = None
 
+        if isinstance(request, list) and request[idx].return_hidden_states:
+            include_hidden_states = True
+        elif not isinstance(request, list) and request.return_hidden_states:
+            include_hidden_states = True
+        else:
+            include_hidden_states = False
+        if include_hidden_states and ret_item["meta_info"].get("hidden_states", None):
+            hidden_states = ret_item["meta_info"]["hidden_states"]
+            hidden_states = (
+                hidden_states[-1] if hidden_states and len(hidden_states) > 1 else []
+            )
+        else:
+            hidden_states = None
+
         finish_reason = ret_item["meta_info"]["finish_reason"]
 
         tool_calls = None
@@ -1391,6 +1451,8 @@ def v1_chat_generate_response(
                     else None
                 ),
             }
+            if hidden_states is not None:
+                choice_data["hidden_states"] = hidden_states
         else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,
@@ -1407,6 +1469,7 @@ def v1_chat_generate_response(
                     if finish_reason and "matched" in finish_reason
                    else None
                ),
+                hidden_states=hidden_states,
            )
 
         choices.append(choice_data)
@@ -1486,12 +1549,16 @@ async def v1_chat_completions(
             prompt_tokens = {}
             completion_tokens = {}
             cached_tokens = {}
+            hidden_states = {}
 
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
                     index = content.get("index", 0)
 
                     text = content["text"]
+                    hidden_states[index] = content["meta_info"].get(
+                        "hidden_states", None
+                    ) or hidden_states.get(index)
 
                     is_first = is_firsts.get(index, True)
                     stream_buffer = stream_buffers.get(index, "")
@@ -1613,6 +1680,7 @@
                     if (delta and len(delta) == 0) or not delta:
                         stream_buffers[index] = new_stream_buffer
                         is_firsts[index] = is_first
+                        n_prev_tokens[index] = n_prev_token
                         continue
 
                     if request.tool_choice != "none" and request.tools:
@@ -1702,6 +1770,7 @@
 
                         stream_buffers[index] = new_stream_buffer
                         is_firsts[index] = is_first
+                        n_prev_tokens[index] = n_prev_token
 
                     else:
                         # No tool calls => just treat this as normal text
@@ -1734,6 +1803,7 @@
                         yield f"data: {chunk.model_dump_json()}\n\n"
                         stream_buffers[index] = new_stream_buffer
                         is_firsts[index] = is_first
+                        n_prev_tokens[index] = n_prev_token
 
                 if finish_reason_type == "stop" and request.tool_choice != "none":
                     parser = FunctionCallParser(
                         tools=request.tools,
@@ -1769,6 +1839,28 @@
                 else:
                     usage = None
 
+                if request.return_hidden_states and hidden_states:
+                    for index, choice_hidden_states in hidden_states.items():
+                        last_token_hidden_states = (
+                            choice_hidden_states[-1]
+                            if choice_hidden_states and len(choice_hidden_states) > 1
+                            else []
+                        )
+                        hidden_states_chunk = ChatCompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            created=created,
+                            choices=[
+                                ChatCompletionResponseStreamChoice(
+                                    index=index,
+                                    delta=DeltaMessage(
+                                        hidden_states=last_token_hidden_states
+                                    ),
+                                    finish_reason=finish_reason_type,
+                                )
+                            ],
+                            model=request.model,
+                        )
+                        yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
                 final_usage_chunk = ChatCompletionStreamResponse(
                     id=content["meta_info"]["id"],
                     created=created,
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index 351c1c567..2d2b76155 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -16,7 +16,7 @@
 import time
 from typing import Dict, List, Optional, Union
 
-from pydantic import BaseModel, Field, root_validator
+from pydantic import BaseModel, Field, model_serializer, root_validator
 from typing_extensions import Literal
 
@@ -182,6 +182,7 @@ class CompletionRequest(BaseModel):
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
     session_params: Optional[Dict] = None
+    return_hidden_states: Optional[bool] = False
 
     # For PD disaggregation
     bootstrap_host: Optional[str] = None
@@ -195,6 +196,11 @@ class CompletionResponseChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Literal["stop", "length", "content_filter", "abort"]
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])
 
 
 class CompletionResponse(BaseModel):
@@ -212,6 +218,11 @@ class CompletionResponseStreamChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])
 
 
 class CompletionStreamResponse(BaseModel):
@@ -405,6 +416,9 @@ class ChatCompletionRequest(BaseModel):
     bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None
 
+    # Hidden States
+    return_hidden_states: Optional[bool] = False
+
 
 class ChatMessage(BaseModel):
     role: Optional[str] = None
@@ -421,6 +435,11 @@ class ChatCompletionResponseChoice(BaseModel):
         "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
     ]
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])
 
 
 class ChatCompletionResponse(BaseModel):
@@ -437,6 +456,11 @@ class DeltaMessage(BaseModel):
     content: Optional[str] = None
     reasoning_content: Optional[str] = None
     tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
@@ -513,3 +537,8 @@ class ScoringResponse(BaseModel):
     model: str
     usage: Optional[UsageInfo] = None
     object: str = "scoring"
+
+
+def exclude_if_none(obj, field_names: List[str]):
+    omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
+    return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 7e4571333..e1fc3be03 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -215,6 +215,7 @@ class ServerArgs:
     disable_chunked_prefix_cache: bool = False
     disable_fast_image_processor: bool = False
     warmups: Optional[str] = None
+    enable_return_hidden_states: bool = False
 
     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None
@@ -1456,6 +1457,12 @@
             default=ServerArgs.debug_tensor_dump_inject,
             help="Inject the outputs from jax as the input of every layer.",
         )
+
+        parser.add_argument(
+            "--enable-return-hidden-states",
+            action="store_true",
+            help="Enable returning hidden states with responses.",
+        )
         parser.add_argument(
             "--debug-tensor-dump-prefill-only",
             action="store_true",
diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
index 6d8c10047..001404bd8 100644
--- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
@@ -117,9 +117,7 @@ class EAGLEDraftCudaGraphRunner:
         hidden_states = self.hidden_states[:num_seqs]
 
         spec_info = EagleDraftInput(
-            topk_p=topk_p,
-            topk_index=topk_index,
-            hidden_states=hidden_states,
+            topk_p=topk_p, topk_index=topk_index, hidden_states=hidden_states
         )
 
         # Forward batch
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index bc0b50f31..0597ad4e0 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -290,6 +290,7 @@ class EAGLEWorker(TpModelWorker):
             A tuple of the final logit output of the target model, next tokens accepted,
             the batch id (used for overlap schedule), and number of accepted tokens.
""" + if batch.forward_mode.is_decode(): with self.draft_tp_context(self.draft_model_runner.tp_group): spec_info = self.draft(batch) @@ -431,10 +432,10 @@ class EAGLEWorker(TpModelWorker): batch.out_cache_loc = out_cache_loc batch.seq_lens_sum = torch.sum(batch.seq_lens).item() spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0) - - # Get forward batch spec_info.capture_hidden_mode = CaptureHiddenMode.LAST + batch.return_hidden_states = False model_worker_batch = batch.get_model_worker_batch() + assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner ) @@ -547,11 +548,13 @@ class EAGLEWorker(TpModelWorker): def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): spec_info.prepare_for_verify(batch, self.page_size) + batch.return_hidden_states = False batch.forward_mode = ForwardMode.TARGET_VERIFY batch.spec_info = spec_info model_worker_batch = batch.get_model_worker_batch( seq_lens_cpu_cache=spec_info.seq_lens_cpu ) + assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode if batch.has_grammar: retrieve_next_token_cpu = spec_info.retrive_next_token.cpu() @@ -687,15 +690,18 @@ class EAGLEWorker(TpModelWorker): hidden_states: Hidden states from the target model forward next_token_ids: Next token ids generated from the target forward. """ + # Sometimes we get hidden states produced by CaptureHiddenMode.FULL, so we have to select just the last batch.spec_info = EagleDraftInput( hidden_states=hidden_states, verified_id=next_token_ids, ) + batch.return_hidden_states = False batch.spec_info.prepare_for_extend(batch) batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST model_worker_batch = batch.get_model_worker_batch( seq_lens_cpu_cache=seq_lens_cpu ) + assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner ) @@ -718,7 +724,9 @@ class EAGLEWorker(TpModelWorker): batch, self.speculative_num_steps, ) + batch.return_hidden_states = False model_worker_batch = batch.get_model_worker_batch() + assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner ) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 296982e09..b1883e1a9 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -59,6 +59,7 @@ suites = { TestFile("test_openai_adapter.py", 1), TestFile("test_openai_function_calling.py", 60), TestFile("test_openai_server.py", 149), + TestFile("test_openai_server_hidden_states.py", 240), TestFile("test_penalty.py", 41), TestFile("test_page_size.py", 60), TestFile("test_pytorch_sampling_backend.py", 66), diff --git a/test/srt/test_hidden_states.py b/test/srt/test_hidden_states.py index 81e42f7b1..2046ce529 100644 --- a/test/srt/test_hidden_states.py +++ b/test/srt/test_hidden_states.py @@ -23,6 +23,7 @@ class TestHiddenState(CustomTestCase): model_path=model_path, random_seed=42, skip_tokenizer_init=True, + enable_return_hidden_states=True, ) outputs = engine.generate( input_ids=input_ids, @@ -96,6 +97,7 @@ class TestHiddenState(CustomTestCase): model_path=model_path, random_seed=42, skip_tokenizer_init=True, + enable_return_hidden_states=True, ) outputs_completion_first_round = engine.generate( input_ids=input_ids, diff --git a/test/srt/test_io_struct.py b/test/srt/test_io_struct.py index 1be077366..b8fdec8ec 100644 --- 
--- a/test/srt/test_io_struct.py
+++ b/test/srt/test_io_struct.py
@@ -381,12 +381,14 @@ class TestGenerateReqInputNormalization(CustomTestCase):
             logprob_start_len=[10, 5],
             top_logprobs_num=[5, 3],
             token_ids_logprob=[[7, 8, 9], [4, 5, 6]],
+            return_hidden_states=[False, False, True],
         )
         req.normalize_batch_and_arguments()
         self.assertEqual(req.return_logprob, [True, False])
         self.assertEqual(req.logprob_start_len, [10, 5])
         self.assertEqual(req.top_logprobs_num, [5, 3])
         self.assertEqual(req.token_ids_logprob, [[7, 8, 9], [4, 5, 6]])
+        self.assertEqual(req.return_hidden_states, [False, False, True])
 
     def test_custom_logit_processor_normalization(self):
         """Test normalization of custom_logit_processor."""
diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
index 2eec90ba3..d10b953c0 100644
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -1,7 +1,9 @@
 """
 python3 -m unittest test_openai_server.TestOpenAIServer.test_batch
 python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
-
+python3 -m unittest test_openai_server.TestOpenAIServer.test_completion_stream
+python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion
+python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion_stream
 """
 
 import json
@@ -9,6 +11,7 @@
 import re
 import time
 import unittest
 
+import numpy as np
 import openai
 import requests
 
@@ -137,27 +140,29 @@
         for response in generator:
             usage = response.usage
             if usage is not None:
-                assert usage.prompt_tokens > 0
-                assert usage.completion_tokens > 0
-                assert usage.total_tokens > 0
+                assert usage.prompt_tokens > 0, "usage.prompt_tokens was zero"
+                assert usage.completion_tokens > 0, "usage.completion_tokens was zero"
+                assert usage.total_tokens > 0, "usage.total_tokens was zero"
                 continue
 
             index = response.choices[0].index
             is_first = is_firsts.get(index, True)
 
             if logprobs:
-                assert response.choices[0].logprobs
-                assert isinstance(response.choices[0].logprobs.tokens[0], str)
+                assert response.choices[0].logprobs, "no logprobs in response"
+                assert isinstance(
+                    response.choices[0].logprobs.tokens[0], str
+                ), f"{response.choices[0].logprobs.tokens[0]} is not a string"
                 if not (is_first and echo):
                     assert isinstance(
                         response.choices[0].logprobs.top_logprobs[0], dict
-                    )
+                    ), "top_logprobs was not a dictionary"
                     ret_num_top_logprobs = len(
                         response.choices[0].logprobs.top_logprobs[0]
                     )
                     # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
                     # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
-                    assert ret_num_top_logprobs > 0
+                    assert ret_num_top_logprobs > 0, "ret_num_top_logprobs was 0"
 
             if is_first:
                 if echo:
                     assert response.choices[0].text.startswith(
                         prompt
                     ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
                 is_firsts[index] = False
-            assert response.id
-            assert response.created
+            assert response.id, "no id in response"
+            assert response.created, "no created in response"
 
         for index in [i for i in range(parallel_sample_num * num_choices)]:
             assert not is_firsts.get(
@@ -231,27 +236,29 @@
         for response in generator:
             usage = response.usage
             if usage is not None:
-                assert usage.prompt_tokens > 0
-                assert usage.completion_tokens > 0
-                assert usage.total_tokens > 0
+                assert usage.prompt_tokens > 0, "usage.prompt_tokens was zero"
+                assert usage.completion_tokens > 0, "usage.completion_tokens was zero"
+                assert usage.total_tokens > 0, "usage.total_tokens was zero"
                 continue
 
             index = response.choices[0].index
             data = response.choices[0].delta
 
             if is_firsts.get(index, True):
-                assert data.role == "assistant"
+                assert (
+                    data.role == "assistant"
+                ), "data.role was not 'assistant' for first chunk"
                 is_firsts[index] = False
                 continue
 
             if logprobs:
-                assert response.choices[0].logprobs
+                assert response.choices[0].logprobs, "logprobs was not returned"
                 assert isinstance(
                     response.choices[0].logprobs.content[0].top_logprobs[0].token, str
-                )
+                ), "top_logprobs token was not a string"
                 assert isinstance(
                     response.choices[0].logprobs.content[0].top_logprobs, list
-                )
+                ), "top_logprobs was not a list"
                 ret_num_top_logprobs = len(
                     response.choices[0].logprobs.content[0].top_logprobs
                 )
diff --git a/test/srt/test_openai_server_hidden_states.py b/test/srt/test_openai_server_hidden_states.py
new file mode 100644
index 000000000..34e5ddde7
--- /dev/null
+++ b/test/srt/test_openai_server_hidden_states.py
@@ -0,0 +1,356 @@
+import json
+import re
+import time
+import unittest
+from abc import ABC
+
+import numpy as np
+import openai
+import torch
+
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import (
+    DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+    DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+    DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+
+class BaseTestOpenAIServerWithHiddenStates(ABC):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.return_hidden_states = [False, True]
+        cls.use_list_input = [True, False]
+        cls.parallel_sample_nums = [1, 2]
+
+    def test_completion(self):
+        for return_hidden_states in self.return_hidden_states:
+            for use_list_input in self.use_list_input:
+                for parallel_sample_num in self.parallel_sample_nums:
+                    self.run_completion(
+                        use_list_input,
+                        parallel_sample_num,
+                        return_hidden_states,
+                    )
+
+    def test_completion_stream(self):
+        # parallel sampling and list input are not supported in streaming mode
+        for return_hidden_states in self.return_hidden_states:
+            for use_list_input in self.use_list_input:
+                for parallel_sample_num in self.parallel_sample_nums:
+                    self.run_completion_stream(
+                        use_list_input,
+                        parallel_sample_num,
+                        return_hidden_states,
+                    )
+
+    def test_chat_completion(self):
+        for return_hidden_states in self.return_hidden_states:
+            for (
+                parallel_sample_num
+            ) in (
+                self.parallel_sample_nums
+            ):  # parallel sample num 2 breaks in the adapter with a 400 for EAGLE
+                self.run_chat_completion(parallel_sample_num, return_hidden_states)
+
+    def test_chat_completion_stream(self):
+        for return_hidden_states in self.return_hidden_states:
+            for (
+                parallel_sample_num
+            ) in (
+                self.parallel_sample_nums
+            ):  # parallel sample num > 1 breaks in the adapter with a 400 for EAGLE
+                self.run_chat_completion_stream(
+                    parallel_sample_num, return_hidden_states
+                )
+
+    def run_completion(
+        self,
+        use_list_input,
+        parallel_sample_num,
+        return_hidden_states,
+    ):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        prompt = "The capital of France is"
+        prompt_input = prompt
+
+        if use_list_input:
+            prompt_arg = [prompt_input, prompt_input]
+            num_choices = len(prompt_arg)
+        else:
+            prompt_arg = prompt_input
+            num_choices = 1
+
+        response = client.completions.create(
+            model=self.model,
+            prompt=prompt_arg,
+            temperature=0,
+            max_tokens=32,
+            n=parallel_sample_num,
+            extra_body=dict(return_hidden_states=return_hidden_states),
+        )
+
+        for choice in response.choices:
+            assert hasattr(choice, "hidden_states") == return_hidden_states
+            if return_hidden_states:
+                assert choice.hidden_states is not None, "hidden_states was None"
+
+    def run_completion_stream(
+        self,
+        use_list_input,
+        parallel_sample_num,
+        return_hidden_states,
+    ):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        prompt = "The capital of France is"
+        prompt_input = prompt
+        num_prompt_tokens = len(self.tokenizer.encode(prompt))
+
+        if use_list_input:
+            prompt_arg = [prompt_input, prompt_input]
+            num_choices = len(prompt_arg)
+            num_prompt_tokens *= 2
+        else:
+            prompt_arg = prompt_input
+            num_choices = 1
+
+        generator = client.completions.create(
+            model=self.model,
+            prompt=prompt_arg,
+            temperature=0,
+            max_tokens=32,
+            stream=True,
+            stream_options={"include_usage": True},
+            n=parallel_sample_num,
+            extra_body=dict(return_hidden_states=return_hidden_states),
+        )
+
+        hidden_states_list = []
+        for response in generator:
+            usage = response.usage
+            for choice in response.choices:
+                if hasattr(choice, "hidden_states"):
+                    assert return_hidden_states
+                    assert choice.hidden_states is not None
+                    hidden_states_list.append(choice.hidden_states)
+
+        if return_hidden_states:
+            assert (
+                len(hidden_states_list) == parallel_sample_num * num_choices
+            ), f"Expected {parallel_sample_num * num_choices} hidden states, got {len(hidden_states_list)}"
+        else:
+            assert (
+                hidden_states_list == []
+            ), "hidden_states were returned and should not have been"
+
+    def run_chat_completion(self, parallel_sample_num, return_hidden_states):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        response = client.chat.completions.create(
+            model=self.model,
+            messages=[
+                {"role": "system", "content": "You are a helpful AI assistant"},
+                {
+                    "role": "user",
+                    "content": "What is the capital of France? Answer in a few words.",
+                },
+            ],
+            temperature=0,
+            n=parallel_sample_num,
+            extra_body=dict(return_hidden_states=return_hidden_states),
+        )
+
+        for choice in response.choices:
+            assert hasattr(choice, "hidden_states") == return_hidden_states
+            if return_hidden_states:
+                assert choice.hidden_states is not None, "hidden_states was None"
+
+    def run_chat_completion_stream(
+        self, parallel_sample_num=1, return_hidden_states=False
+    ):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        generator = client.chat.completions.create(
+            model=self.model,
+            messages=[
+                {"role": "system", "content": "You are a helpful AI assistant"},
+                {"role": "user", "content": "What is the capital of France?"},
+            ],
+            temperature=0,
+            stream=True,
+            stream_options={"include_usage": True},
+            n=parallel_sample_num,
+            extra_body=dict(return_hidden_states=return_hidden_states),
+        )
+
+        is_firsts = {}
+        hidden_states_list = []
+
+        for response in generator:
+            for choice in response.choices:
+                if hasattr(choice.delta, "hidden_states"):
+                    assert return_hidden_states
+                    assert choice.delta.hidden_states is not None
+                    hidden_states_list.append(choice.delta.hidden_states)
+
+        if return_hidden_states:
+            assert (
+                len(hidden_states_list) == parallel_sample_num
+            ), f"Expected {parallel_sample_num} hidden states, got {len(hidden_states_list)}"
+        else:
+            assert (
+                hidden_states_list == []
+            ), "hidden_states were returned and should not have been"
+
+
+class TestOpenAIServerWithHiddenStatesEnabled(
+    CustomTestCase, BaseTestOpenAIServerWithHiddenStates
+):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            api_key=cls.api_key,
+            other_args=["--enable-return-hidden-states"],
+        )
+        cls.base_url += "/v1"
+        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
+        cls.return_hidden_states = [False, True]
+        cls.use_list_input = [True, False]
+        cls.parallel_sample_nums = [1, 2]
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+
+class TestOpenAIServerWithHiddenStatesEnabledAndCUDAGraphDisabled(
+    CustomTestCase, BaseTestOpenAIServerWithHiddenStates
+):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            api_key=cls.api_key,
+            other_args=["--enable-return-hidden-states", "--disable-cuda-graph"],
+        )
+        cls.base_url += "/v1"
+        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
+        cls.return_hidden_states = [False, True]
+        cls.use_list_input = [True, False]
+        cls.parallel_sample_nums = [1]
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+
+class TestOpenAIServerWithEAGLEAndHiddenStatesEnabled(
+    CustomTestCase, BaseTestOpenAIServerWithHiddenStates
+):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.speculative_draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
+        cls.speculative_algorithm = "EAGLE"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--speculative-algorithm",
+                "EAGLE",
+                "--speculative-draft-model-path",
+                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+                "--speculative-num-steps",
+                5,
+                "--speculative-eagle-topk",
+                8,
+                "--speculative-num-draft-tokens",
+                64,
+                "--mem-fraction-static",
+                0.7,
+                "--chunked-prefill-size",
+                128,
+                "--max-running-requests",
+                8,
+                "--enable-return-hidden-states",
+            ],
+        )
+        cls.base_url += "/v1"
+        cls.tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
+        cls.return_hidden_states = [False, True]
+        cls.use_list_input = [True, False]
+        cls.parallel_sample_nums = [1]
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+
+class TestOpenAIServerWithEAGLE3AndHiddenStatesEnabled(
+    CustomTestCase, BaseTestOpenAIServerWithHiddenStates
+):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "meta-llama/Llama-3.1-8B-Instruct"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.speculative_algorithm = "EAGLE3"
+        cls.speculative_draft_model = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--speculative-algorithm",
+                cls.speculative_algorithm,
+                "--speculative-draft-model-path",
+                cls.speculative_draft_model,
+                "--speculative-num-steps",
+                5,
+                "--speculative-eagle-topk",
+                16,
+                "--speculative-num-draft-tokens",
+                64,
+                "--mem-fraction-static",
+                0.7,
+                "--chunked-prefill-size",
+                128,
+                "--max-running-requests",
+                8,
+                "--dtype",
+                "float16",
+                "--enable-return-hidden-states",
+            ],
+        )
+        cls.base_url += "/v1"
+        cls.tokenizer = get_tokenizer(cls.model)
+        cls.return_hidden_states = [False, True]
+        cls.use_list_input = [True, False]
+        cls.parallel_sample_nums = [1]
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+
+if __name__ == "__main__":
+    unittest.main()
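
Usage sketch for reviewers (a minimal client-side example, not part of the patch): it assumes a server launched locally with --enable-return-hidden-states, and the port, API key, and model name are illustrative placeholders. A client opts in per request via return_hidden_states and reads the extra hidden_states attribute on each choice.

    # Minimal sketch: request hidden states through the OpenAI-compatible API.
    # Assumes a server launched roughly like:
    #   python -m sglang.launch_server --model-path <model> --enable-return-hidden-states --port 30000
    # Port, API key, and model name below are placeholders.
    import openai

    client = openai.Client(api_key="sk-123456", base_url="http://localhost:30000/v1")

    # Opt in per request; without the server flag, the tokenizer manager rejects this request.
    response = client.completions.create(
        model="default",
        prompt="The capital of France is",
        temperature=0,
        max_tokens=8,
        extra_body=dict(return_hidden_states=True),
    )

    for choice in response.choices:
        # The adapter attaches the last token's hidden state to each choice;
        # the field is omitted entirely when hidden states were not requested.
        hidden = getattr(choice, "hidden_states", None)
        if hidden is not None:
            print(f"choice {choice.index}: hidden state vector of length {len(hidden)}")

The same opt-in works in streaming mode: per the adapter changes above, the hidden states arrive as one extra chunk after the text chunks (on choice.hidden_states for completions, delta.hidden_states for chat completions), just before the usage chunk.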