Open AI API hidden states (#6716)
@@ -135,6 +135,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `download_dir` | Overrides the default Hugging Face cache directory for model weights. | None |
 | `base_gpu_id` | Sets the first GPU to use when distributing the model across multiple GPUs. | `0` |
 | `allow_auto_truncate` | Automatically truncate requests that exceed the maximum input length. | `False` |
+| `enable_return_hidden_states` | Enables returning hidden states to the user. | `False` |

 ## Logging

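For orientation: the server flag only enables the capability; each request still has to opt in through the `return_hidden_states` field added to the OpenAI-compatible protocol further down in this diff. A minimal client-side sketch (URL, API key, and model name are placeholders, assuming a server launched with `--enable-return-hidden-states`):

```python
import openai

# Placeholder URL/key; assumes an SGLang server started with --enable-return-hidden-states.
client = openai.Client(api_key="sk-123456", base_url="http://localhost:30000/v1")

response = client.completions.create(
    model="Alibaba-NLP/gte-Qwen2-1.5B-instruct",  # placeholder model name
    prompt="The capital of France is",
    temperature=0,
    max_tokens=16,
    # Per-request opt-in, forwarded to GenerateReqInput.return_hidden_states.
    extra_body={"return_hidden_states": True},
)

for choice in response.choices:
    # When requested, the choice carries the last token's hidden state.
    print(len(choice.hidden_states))
```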
@@ -22,6 +22,7 @@ def main():
     # Create an LLM.
     llm = sgl.Engine(
         model_path="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
+        enable_return_hidden_states=True,
     )

     sampling_params = {

@@ -23,7 +23,7 @@ else:
 def main():
     # Launch the server
     server_process, port = launch_server_cmd(
-        "python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-1.5B-instruct --host 0.0.0.0"
+        "python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-1.5B-instruct --enable-return-hidden-states --host 0.0.0.0"
     )
     wait_for_server(f"http://localhost:{port}")

@@ -99,7 +99,7 @@ class GenerateReqInput:
     custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None

     # Whether to return hidden states
-    return_hidden_states: bool = False
+    return_hidden_states: Union[List[bool], bool] = False

     # For disaggregated inference
     bootstrap_host: Optional[Union[List[str], str]] = None

@@ -409,7 +409,11 @@ class GenerateReqInput:
                 if self.custom_logit_processor is not None
                 else None
             ),
-            return_hidden_states=self.return_hidden_states,
+            return_hidden_states=(
+                self.return_hidden_states[i]
+                if isinstance(self.return_hidden_states, list)
+                else self.return_hidden_states
+            ),
             # if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
             bootstrap_host=(
                 self.bootstrap_host[i] if self.bootstrap_host is not None else None
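With `return_hidden_states` now accepting a per-request list, batched inputs are split element-wise when the batch is indexed. A small sketch of the intended normalization behavior, mirroring the unit test added near the end of this diff (import path and prompts are assumptions for illustration):

```python
from sglang.srt.managers.io_struct import GenerateReqInput

# Batched request: only the second prompt asks for hidden states.
req = GenerateReqInput(
    text=["What is the capital of France?", "Explain attention briefly."],
    return_hidden_states=[False, True],
)
req.normalize_batch_and_arguments()

# __getitem__ now picks the per-index flag instead of handing back the whole list.
assert req[0].return_hidden_states is False
assert req[1].return_hidden_states is True
```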
@@ -418,6 +418,20 @@ class TokenizerManager:

         obj.normalize_batch_and_arguments()

+        if isinstance(obj, GenerateReqInput):
+            return_hidden_states = obj.return_hidden_states
+            has_return_hidden_states = return_hidden_states == True or (
+                isinstance(return_hidden_states, list) and any(return_hidden_states)
+            )
+            if (
+                not self.server_args.enable_return_hidden_states
+                and has_return_hidden_states
+            ):
+                raise ValueError(
+                    "return_hidden_states=True requires the server to be started "
+                    "with --enable-return-hidden-states (ServerArgs.enable_return_hidden_states)."
+                )
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(

@@ -235,6 +235,10 @@ class CudaGraphRunner:
                 self.model_runner.server_args.speculative_num_draft_tokens
             )

+        # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
+        if model_runner.server_args.enable_return_hidden_states:
+            self.capture_hidden_mode = CaptureHiddenMode.FULL
+
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs

@@ -342,11 +346,29 @@ class CudaGraphRunner:
             else True
         )

+        requested_capture_hidden_mode = max(
+            forward_batch.capture_hidden_mode,
+            (
+                forward_batch.spec_info.capture_hidden_mode
+                if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
+                is not None
+                else CaptureHiddenMode.NULL
+            ),
+        )
+        capture_hidden_mode_matches = (
+            requested_capture_hidden_mode == CaptureHiddenMode.NULL
+            or requested_capture_hidden_mode == self.capture_hidden_mode
+        )
         is_tbo_supported = (
             forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
         )

-        return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
+        return (
+            is_bs_supported
+            and is_encoder_lens_supported
+            and is_tbo_supported
+            and capture_hidden_mode_matches
+        )

     def capture(self) -> None:
         profile_context = empty_context()

@@ -541,21 +563,34 @@ class CudaGraphRunner:
         return graph, out

     def recapture_if_needed(self, forward_batch: ForwardBatch):
-        # If the capture_hidden_mode changes, we need to recapture the graph
-        hidden_mode_from_spec_info = getattr(
+        # If the required capture_hidden_mode changes, we need to recapture the graph
+
+        # These are the different factors that can influence the capture_hidden_mode
+        capture_hidden_mode_required_by_forward_batch = (
+            forward_batch.capture_hidden_mode
+        )
+        capture_hidden_mode_required_by_spec_info = getattr(
             forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
         )
-        if (
-            forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
-            and self.capture_hidden_mode != CaptureHiddenMode.FULL
-        ):
-            self.capture_hidden_mode = CaptureHiddenMode.FULL
-            self.capture()
-        elif (
-            forward_batch.capture_hidden_mode != CaptureHiddenMode.FULL
-            and self.capture_hidden_mode != hidden_mode_from_spec_info
-        ):
-            self.capture_hidden_mode = hidden_mode_from_spec_info
+        capture_hidden_mode_required_for_returning_hidden_states = (
+            CaptureHiddenMode.FULL
+            if self.model_runner.server_args.enable_return_hidden_states
+            else CaptureHiddenMode.NULL
+        )
+
+        # Determine the highest capture_hidden_mode required
+        # (If we have FULL, we can emulate LAST or NULL)
+        # (If we have LAST, we can emulate NULL)
+        required_capture_hidden_mode = max(
+            capture_hidden_mode_required_by_forward_batch,
+            capture_hidden_mode_required_by_spec_info,
+            capture_hidden_mode_required_for_returning_hidden_states,
+        )
+
+        # If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture
+        if self.capture_hidden_mode != required_capture_hidden_mode:
+            self.capture_hidden_mode = required_capture_hidden_mode
             self.capture()

     def replay_prepare(

@@ -31,6 +31,7 @@ from __future__ import annotations

 from dataclasses import dataclass
 from enum import IntEnum, auto
+from functools import total_ordering
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import torch

@@ -117,13 +118,14 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.DECODE or self == ForwardMode.IDLE


+@total_ordering
 class CaptureHiddenMode(IntEnum):
     # Do not capture anything.
-    NULL = auto()
-    # Capture hidden states of all tokens.
-    FULL = auto()
+    NULL = 0
     # Capture a hidden state of the last token.
-    LAST = auto()
+    LAST = 1
+    # Capture hidden states of all tokens.
+    FULL = 2

     def need_capture(self):
         return self != CaptureHiddenMode.NULL

@@ -134,6 +136,9 @@ class CaptureHiddenMode(IntEnum):
     def is_last(self):
         return self == CaptureHiddenMode.LAST

+    def __lt__(self, other):
+        return self.value < other.value
+

 @dataclass
 class ForwardBatch:
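Giving the enum explicit values in NULL < LAST < FULL order and decorating it with `@total_ordering` is what lets the CUDA-graph code above pick the strongest requirement with a plain `max(...)`. A small illustration of the intended semantics (the import path is an assumption):

```python
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode

# NULL < LAST < FULL, so max() selects the most demanding capture mode.
assert max(CaptureHiddenMode.NULL, CaptureHiddenMode.LAST) == CaptureHiddenMode.LAST
assert max(CaptureHiddenMode.LAST, CaptureHiddenMode.FULL) == CaptureHiddenMode.FULL

# A graph captured with FULL can also serve LAST or NULL requests, which is why
# recapture_if_needed() keys off the maximum requirement across all sources.
```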
@@ -542,6 +542,7 @@ def v1_generate_request(
     logprob_start_lens = []
     top_logprobs_nums = []
     lora_paths = []
+    return_hidden_states = []

     for request in all_requests:
         # NOTE: with openai API, the prompt's logprobs are always not computed

@@ -588,6 +589,7 @@ def v1_generate_request(
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
+        return_hidden_states.append(request.return_hidden_states)

     if len(all_requests) == 1:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):

@@ -599,6 +601,7 @@ def v1_generate_request(
         logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
         lora_paths = lora_paths[0]
+        return_hidden_states = return_hidden_states[0]
     else:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}

@@ -615,6 +618,7 @@ def v1_generate_request(
         stream=all_requests[0].stream,
         rid=request_ids,
         lora_path=lora_paths,
+        return_hidden_states=return_hidden_states,
         bootstrap_host=all_requests[0].bootstrap_host,
         bootstrap_port=all_requests[0].bootstrap_port,
         bootstrap_room=all_requests[0].bootstrap_room,

@@ -683,6 +687,16 @@ def v1_generate_response(
         else:
             logprobs = None

+        hidden_states = None
+        if isinstance(request, list) and request[idx].return_hidden_states:
+            hidden_states = ret_item["meta_info"].get("hidden_states", None)
+        elif (not isinstance(request, list)) and request.return_hidden_states:
+            hidden_states = ret_item["meta_info"].get("hidden_states", None)
+        if hidden_states is not None:
+            hidden_states = (
+                hidden_states[-1] if hidden_states and len(hidden_states) > 1 else []
+            )
+
         finish_reason = ret_item["meta_info"]["finish_reason"]

         if to_file:

@@ -698,6 +712,8 @@ def v1_generate_response(
                     else None
                 ),
             }
+            if hidden_states is not None:
+                choice_data["hidden_states"] = hidden_states
         else:
             choice_data = CompletionResponseChoice(
                 index=idx,
@@ -709,6 +725,7 @@ def v1_generate_response(
                     if finish_reason and "matched" in finish_reason
                     else None
                 ),
+                hidden_states=hidden_states,
             )

         choices.append(choice_data)

@@ -777,6 +794,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
         prompt_tokens = {}
         completion_tokens = {}
         cached_tokens = {}
+        hidden_states = {}

         try:
             async for content in tokenizer_manager.generate_request(

@@ -791,6 +809,9 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                 prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                 completion_tokens[index] = content["meta_info"]["completion_tokens"]
                 cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
+                hidden_states[index] = content["meta_info"].get(
+                    "hidden_states", None
+                ) or hidden_states.get(index)

                 if not stream_buffer:  # The first chunk
                     if request.echo:

@@ -873,6 +894,27 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                 n_prev_tokens[index] = n_prev_token

                 yield f"data: {chunk.model_dump_json()}\n\n"
+            if request.return_hidden_states and hidden_states:
+                for index, choice_hidden_states in hidden_states.items():
+                    last_token_hidden_states = (
+                        choice_hidden_states[-1]
+                        if choice_hidden_states and len(choice_hidden_states) > 1
+                        else []
+                    )
+                    hidden_states_chunk = CompletionStreamResponse(
+                        id=content["meta_info"]["id"],
+                        created=created,
+                        choices=[
+                            CompletionResponseStreamChoice(
+                                text="",
+                                index=index,
+                                hidden_states=last_token_hidden_states,
+                                finish_reason=None,
+                            )
+                        ],
+                        model=request.model,
+                    )
+                    yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
             if request.stream_options and request.stream_options.include_usage:
                 total_prompt_tokens = sum(
                     tokens
@@ -973,6 +1015,7 @@ def v1_chat_generate_request(
     top_logprobs_nums = []
    modalities_list = []
     lora_paths = []
+    return_hidden_states = []

     # NOTE: with openai API, the prompt's logprobs are always not computed

@@ -1215,6 +1258,7 @@ def v1_chat_generate_request(
         image_data_list.append(image_data)
         audio_data_list.append(audio_data)
         modalities_list.append(modalities)
+        return_hidden_states.append(request.return_hidden_states)
     if len(all_requests) == 1:
         if is_multimodal:
             # processor will need text input

@@ -1233,6 +1277,7 @@ def v1_chat_generate_request(
         modalities_list = modalities_list[0]
         lora_paths = lora_paths[0]
         request_ids = request_ids[0]
+        return_hidden_states = return_hidden_states[0]
     else:
         if tokenizer_manager.model_config.is_multimodal:
             # processor will need text input

@@ -1259,6 +1304,7 @@ def v1_chat_generate_request(
         bootstrap_host=all_requests[0].bootstrap_host,
         bootstrap_port=all_requests[0].bootstrap_port,
         bootstrap_room=all_requests[0].bootstrap_room,
+        return_hidden_states=return_hidden_states,
     )

     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]

@@ -1319,6 +1365,20 @@ def v1_chat_generate_response(
         else:
             choice_logprobs = None

+        if isinstance(request, list) and request[idx].return_hidden_states:
+            include_hidden_states = True
+        elif not isinstance(request, list) and request.return_hidden_states:
+            include_hidden_states = True
+        else:
+            include_hidden_states = False
+        if include_hidden_states and ret_item["meta_info"].get("hidden_states", None):
+            hidden_states = ret_item["meta_info"]["hidden_states"]
+            hidden_states = (
+                hidden_states[-1] if hidden_states and len(hidden_states) > 1 else []
+            )
+        else:
+            hidden_states = None
+
         finish_reason = ret_item["meta_info"]["finish_reason"]

         tool_calls = None
@@ -1391,6 +1451,8 @@
                     else None
                 ),
             }
+            if hidden_states is not None:
+                choice_data["hidden_states"] = hidden_states
         else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,

@@ -1407,6 +1469,7 @@ def v1_chat_generate_response(
                     if finish_reason and "matched" in finish_reason
                     else None
                 ),
+                hidden_states=hidden_states,
             )

         choices.append(choice_data)

@@ -1486,12 +1549,16 @@ async def v1_chat_completions(
         prompt_tokens = {}
         completion_tokens = {}
         cached_tokens = {}
+        hidden_states = {}
         try:
             async for content in tokenizer_manager.generate_request(
                 adapted_request, raw_request
             ):
                 index = content.get("index", 0)
                 text = content["text"]
+                hidden_states[index] = content["meta_info"].get(
+                    "hidden_states", None
+                ) or hidden_states.get(index)

                 is_first = is_firsts.get(index, True)
                 stream_buffer = stream_buffers.get(index, "")
@@ -1613,6 +1680,7 @@ async def v1_chat_completions(
                     if (delta and len(delta) == 0) or not delta:
                         stream_buffers[index] = new_stream_buffer
                         is_firsts[index] = is_first
+                        n_prev_tokens[index] = n_prev_token
                         continue

                 if request.tool_choice != "none" and request.tools:

@@ -1702,6 +1770,7 @@

                         stream_buffers[index] = new_stream_buffer
                         is_firsts[index] = is_first
+                        n_prev_tokens[index] = n_prev_token

                 else:
                     # No tool calls => just treat this as normal text

@@ -1734,6 +1803,7 @@ async def v1_chat_completions(
                     yield f"data: {chunk.model_dump_json()}\n\n"
                     stream_buffers[index] = new_stream_buffer
                     is_firsts[index] = is_first
+                    n_prev_tokens[index] = n_prev_token
             if finish_reason_type == "stop" and request.tool_choice != "none":
                 parser = FunctionCallParser(
                     tools=request.tools,

@@ -1769,6 +1839,28 @@ async def v1_chat_completions(

                 else:
                     usage = None
+                if request.return_hidden_states and hidden_states:
+                    for index, choice_hidden_states in hidden_states.items():
+                        last_token_hidden_states = (
+                            choice_hidden_states[-1]
+                            if choice_hidden_states and len(choice_hidden_states) > 1
+                            else []
+                        )
+                        hidden_states_chunk = ChatCompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            created=created,
+                            choices=[
+                                ChatCompletionResponseStreamChoice(
+                                    index=index,
+                                    delta=DeltaMessage(
+                                        hidden_states=last_token_hidden_states
+                                    ),
+                                    finish_reason=finish_reason_type,
+                                )
+                            ],
+                            model=request.model,
+                        )
+                        yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
                 final_usage_chunk = ChatCompletionStreamResponse(
                     id=content["meta_info"]["id"],
                     created=created,
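When streaming, the last-token hidden state therefore arrives as an extra chunk whose delta carries a `hidden_states` field. A client-side sketch of consuming it, in line with the streaming test added below (URL, key, and model name are placeholders):

```python
import openai

client = openai.Client(api_key="sk-123456", base_url="http://localhost:30000/v1")

stream = client.chat.completions.create(
    model="Alibaba-NLP/gte-Qwen2-1.5B-instruct",  # placeholder model name
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    temperature=0,
    stream=True,
    extra_body={"return_hidden_states": True},
)

for chunk in stream:
    for choice in chunk.choices:
        # Normal chunks carry text in delta.content; the extra chunk carries
        # the last token's hidden state in delta.hidden_states.
        if getattr(choice.delta, "hidden_states", None) is not None:
            print("hidden state size:", len(choice.delta.hidden_states))
        elif choice.delta.content:
            print(choice.delta.content, end="")
```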
@@ -16,7 +16,7 @@
 import time
 from typing import Dict, List, Optional, Union

-from pydantic import BaseModel, Field, root_validator
+from pydantic import BaseModel, Field, model_serializer, root_validator
 from typing_extensions import Literal


@@ -182,6 +182,7 @@ class CompletionRequest(BaseModel):
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
     session_params: Optional[Dict] = None
+    return_hidden_states: Optional[bool] = False

     # For PD disaggregation
     bootstrap_host: Optional[str] = None

@@ -195,6 +196,11 @@ class CompletionResponseChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Literal["stop", "length", "content_filter", "abort"]
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])


 class CompletionResponse(BaseModel):

@@ -212,6 +218,11 @@ class CompletionResponseStreamChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])


 class CompletionStreamResponse(BaseModel):

@@ -405,6 +416,9 @@ class ChatCompletionRequest(BaseModel):
     bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None

+    # Hidden States
+    return_hidden_states: Optional[bool] = False
+

 class ChatMessage(BaseModel):
     role: Optional[str] = None

@@ -421,6 +435,11 @@ class ChatCompletionResponseChoice(BaseModel):
         "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
     ]
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])


 class ChatCompletionResponse(BaseModel):

@@ -437,6 +456,11 @@ class DeltaMessage(BaseModel):
     content: Optional[str] = None
     reasoning_content: Optional[str] = None
     tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])


 class ChatCompletionResponseStreamChoice(BaseModel):

@@ -513,3 +537,8 @@ class ScoringResponse(BaseModel):
     model: str
     usage: Optional[UsageInfo] = None
     object: str = "scoring"
+
+
+def exclude_if_none(obj, field_names: List[str]):
+    omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
+    return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
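The `@model_serializer` hooks paired with `exclude_if_none` keep the OpenAI-compatible payloads unchanged when hidden states are not requested: the `hidden_states` key is dropped whenever it is None and serialized only when populated. A sketch of the assumed behavior (import path and field values are illustrative):

```python
from sglang.srt.openai_api.protocol import CompletionResponseChoice

choice = CompletionResponseChoice(index=0, text="Paris", finish_reason="stop")
# hidden_states defaults to None, so the serialized dict omits the key entirely.
assert "hidden_states" not in choice.model_dump()

choice_with_hs = CompletionResponseChoice(
    index=0,
    text="Paris",
    finish_reason="stop",
    hidden_states=[0.1, 0.2, 0.3],  # illustrative values
)
assert choice_with_hs.model_dump()["hidden_states"] == [0.1, 0.2, 0.3]
```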
@@ -215,6 +215,7 @@ class ServerArgs:
     disable_chunked_prefix_cache: bool = False
     disable_fast_image_processor: bool = False
     warmups: Optional[str] = None
+    enable_return_hidden_states: bool = False

     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None

@@ -1456,6 +1457,12 @@ class ServerArgs:
             default=ServerArgs.debug_tensor_dump_inject,
             help="Inject the outputs from jax as the input of every layer.",
         )
+
+        parser.add_argument(
+            "--enable-return-hidden-states",
+            action="store_true",
+            help="Enable returning hidden states with responses.",
+        )
         parser.add_argument(
             "--debug-tensor-dump-prefill-only",
             action="store_true",

@@ -117,9 +117,7 @@ class EAGLEDraftCudaGraphRunner:
         hidden_states = self.hidden_states[:num_seqs]

         spec_info = EagleDraftInput(
-            topk_p=topk_p,
-            topk_index=topk_index,
-            hidden_states=hidden_states,
+            topk_p=topk_p, topk_index=topk_index, hidden_states=hidden_states
         )

         # Forward batch

@@ -290,6 +290,7 @@ class EAGLEWorker(TpModelWorker):
         A tuple of the final logit output of the target model, next tokens accepted,
         the batch id (used for overlap schedule), and number of accepted tokens.
         """
+
         if batch.forward_mode.is_decode():
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)

@@ -431,10 +432,10 @@ class EAGLEWorker(TpModelWorker):
         batch.out_cache_loc = out_cache_loc
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
         spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
-
-        # Get forward batch
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )

@@ -547,11 +548,13 @@

     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
         spec_info.prepare_for_verify(batch, self.page_size)
+        batch.return_hidden_states = False
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=spec_info.seq_lens_cpu
         )
+        assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode

         if batch.has_grammar:
             retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
@@ -687,15 +690,18 @@
             hidden_states: Hidden states from the target model forward
             next_token_ids: Next token ids generated from the target forward.
         """
+        # Sometimes we get hidden states produced by CaptureHiddenMode.FULL, so we have to select just the last
         batch.spec_info = EagleDraftInput(
             hidden_states=hidden_states,
             verified_id=next_token_ids,
         )
+        batch.return_hidden_states = False
         batch.spec_info.prepare_for_extend(batch)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=seq_lens_cpu
         )
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )

@@ -718,7 +724,9 @@
             batch,
             self.speculative_num_steps,
         )
+        batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )

@@ -59,6 +59,7 @@ suites = {
         TestFile("test_openai_adapter.py", 1),
         TestFile("test_openai_function_calling.py", 60),
         TestFile("test_openai_server.py", 149),
+        TestFile("test_openai_server_hidden_states.py", 240),
         TestFile("test_penalty.py", 41),
         TestFile("test_page_size.py", 60),
         TestFile("test_pytorch_sampling_backend.py", 66),

@@ -23,6 +23,7 @@ class TestHiddenState(CustomTestCase):
             model_path=model_path,
             random_seed=42,
             skip_tokenizer_init=True,
+            enable_return_hidden_states=True,
         )
         outputs = engine.generate(
             input_ids=input_ids,

@@ -96,6 +97,7 @@ class TestHiddenState(CustomTestCase):
             model_path=model_path,
             random_seed=42,
             skip_tokenizer_init=True,
+            enable_return_hidden_states=True,
         )
         outputs_completion_first_round = engine.generate(
             input_ids=input_ids,

@@ -381,12 +381,14 @@ class TestGenerateReqInputNormalization(CustomTestCase):
             logprob_start_len=[10, 5],
             top_logprobs_num=[5, 3],
             token_ids_logprob=[[7, 8, 9], [4, 5, 6]],
+            return_hidden_states=[False, False, True],
         )
         req.normalize_batch_and_arguments()
         self.assertEqual(req.return_logprob, [True, False])
         self.assertEqual(req.logprob_start_len, [10, 5])
         self.assertEqual(req.top_logprobs_num, [5, 3])
         self.assertEqual(req.token_ids_logprob, [[7, 8, 9], [4, 5, 6]])
+        self.assertEqual(req.return_hidden_states, [False, False, True])

     def test_custom_logit_processor_normalization(self):
         """Test normalization of custom_logit_processor."""

@@ -1,7 +1,9 @@
 """
 python3 -m unittest test_openai_server.TestOpenAIServer.test_batch
 python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
+python3 -m unittest test_openai_server.TestOpenAIServer.test_completion_stream
+python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion
+python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion_stream
 """

 import json

@@ -9,6 +11,7 @@ import re
 import time
 import unittest

+import numpy as np
 import openai
 import requests

@@ -137,27 +140,29 @@ class TestOpenAIServer(CustomTestCase):
         for response in generator:
             usage = response.usage
             if usage is not None:
-                assert usage.prompt_tokens > 0
-                assert usage.completion_tokens > 0
-                assert usage.total_tokens > 0
+                assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
+                assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
+                assert usage.total_tokens > 0, f"usage.total_tokens was zero"
                 continue

             index = response.choices[0].index
             is_first = is_firsts.get(index, True)

             if logprobs:
-                assert response.choices[0].logprobs
-                assert isinstance(response.choices[0].logprobs.tokens[0], str)
+                assert response.choices[0].logprobs, f"no logprobs in response"
+                assert isinstance(
+                    response.choices[0].logprobs.tokens[0], str
+                ), f"{response.choices[0].logprobs.tokens[0]} is not a string"
                 if not (is_first and echo):
                     assert isinstance(
                         response.choices[0].logprobs.top_logprobs[0], dict
-                    )
+                    ), f"top_logprobs was not a dictionary"
                     ret_num_top_logprobs = len(
                         response.choices[0].logprobs.top_logprobs[0]
                     )
                     # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
                     # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
-                    assert ret_num_top_logprobs > 0
+                    assert ret_num_top_logprobs > 0, f"ret_num_top_logprobs was 0"

             if is_first:
                 if echo:

@@ -165,8 +170,8 @@ class TestOpenAIServer(CustomTestCase):
                     prompt
                 ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
                 is_firsts[index] = False
-            assert response.id
-            assert response.created
+            assert response.id, f"no id in response"
+            assert response.created, f"no created in response"

         for index in [i for i in range(parallel_sample_num * num_choices)]:
             assert not is_firsts.get(

@@ -231,27 +236,29 @@ class TestOpenAIServer(CustomTestCase):
         for response in generator:
             usage = response.usage
             if usage is not None:
-                assert usage.prompt_tokens > 0
-                assert usage.completion_tokens > 0
-                assert usage.total_tokens > 0
+                assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
+                assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
+                assert usage.total_tokens > 0, f"usage.total_tokens was zero"
                 continue

             index = response.choices[0].index
             data = response.choices[0].delta

             if is_firsts.get(index, True):
-                assert data.role == "assistant"
+                assert (
+                    data.role == "assistant"
+                ), f"data.role was not 'assistant' for first chunk"
                 is_firsts[index] = False
                 continue

             if logprobs:
-                assert response.choices[0].logprobs
+                assert response.choices[0].logprobs, f"logprobs was not returned"
                 assert isinstance(
                     response.choices[0].logprobs.content[0].top_logprobs[0].token, str
-                )
+                ), f"top_logprobs token was not a string"
                 assert isinstance(
                     response.choices[0].logprobs.content[0].top_logprobs, list
-                )
+                ), f"top_logprobs was not a list"
                 ret_num_top_logprobs = len(
                     response.choices[0].logprobs.content[0].top_logprobs
                 )
test/srt/test_openai_server_hidden_states.py (new file, 356 lines)
@@ -0,0 +1,356 @@
+import json
+import re
+import time
+import unittest
+from abc import ABC
+
+import numpy as np
+import openai
+import torch
+
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import (
+    DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+    DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+    DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+
+class BaseTestOpenAIServerWithHiddenStates(ABC):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.return_hidden_states = [False, True]
+        cls.use_list_input = [True, False]
+        cls.parallel_sample_nums = [1, 2]
+
+    def test_completion(self):
+        for return_hidden_states in self.return_hidden_states:
+            for use_list_input in self.use_list_input:
+                for parallel_sample_num in self.parallel_sample_nums:
+                    self.run_completion(
+                        use_list_input,
+                        parallel_sample_num,
+                        return_hidden_states,
+                    )
+
+    def test_completion_stream(self):
+        # parallel sampling and list input are not supported in streaming mode
+        for return_hidden_states in self.return_hidden_states:
+            for use_list_input in self.use_list_input:
+                for parallel_sample_num in self.parallel_sample_nums:
+                    self.run_completion_stream(
+                        use_list_input,
+                        parallel_sample_num,
+                        return_hidden_states,
+                    )
+
+    def test_chat_completion(self):
+        for return_hidden_states in self.return_hidden_states:
+            for (
+                parallel_sample_num
+            ) in (
+                self.parallel_sample_nums
+            ):  # parallel sample num 2 breaks in the adapter with a 400 for EAGLE
+                self.run_chat_completion(parallel_sample_num, return_hidden_states)
+
+    def test_chat_completion_stream(self):
+        for return_hidden_states in self.return_hidden_states:
+            for (
+                parallel_sample_num
+            ) in (
+                self.parallel_sample_nums
+            ):  # parallel sample num > 1 breaks in the adapter with a 400 for EAGLE
+                self.run_chat_completion_stream(
+                    parallel_sample_num, return_hidden_states
+                )
+
+    def run_completion(
+        self,
+        use_list_input,
+        parallel_sample_num,
+        return_hidden_states,
+    ):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        prompt = "The capital of France is"
+        prompt_input = prompt
+
+        if use_list_input:
+            prompt_arg = [prompt_input, prompt_input]
+            num_choices = len(prompt_arg)
+        else:
+            prompt_arg = prompt_input
+            num_choices = 1
+
+        response = client.completions.create(
+            model=self.model,
+            prompt=prompt_arg,
+            temperature=0,
+            max_tokens=32,
+            n=parallel_sample_num,
+            extra_body=dict(return_hidden_states=return_hidden_states),
+        )
+
+        for choice in response.choices:
+            assert hasattr(choice, "hidden_states") == return_hidden_states
+            if return_hidden_states:
+                assert choice.hidden_states is not None, "hidden_states was None"
+
+    def run_completion_stream(
+        self,
+        use_list_input,
+        parallel_sample_num,
+        return_hidden_states,
+    ):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        prompt = "The capital of France is"
+        prompt_input = prompt
+        num_prompt_tokens = len(self.tokenizer.encode(prompt))
+
+        if use_list_input:
+            prompt_arg = [prompt_input, prompt_input]
+            num_choices = len(prompt_arg)
+            num_prompt_tokens *= 2
+        else:
+            prompt_arg = prompt_input
+            num_choices = 1
+
+        generator = client.completions.create(
+            model=self.model,
+            prompt=prompt_arg,
+            temperature=0,
+            max_tokens=32,
+            stream=True,
+            stream_options={"include_usage": True},
+            n=parallel_sample_num,
+            extra_body=dict(return_hidden_states=return_hidden_states),
+        )
+
+        hidden_states_list = []
+        for response in generator:
+            usage = response.usage
+            for choice in response.choices:
+                if hasattr(choice, "hidden_states"):
+                    assert return_hidden_states
+                    assert choice.hidden_states is not None
+                    hidden_states_list.append(choice.hidden_states)
+
+        if return_hidden_states:
+            assert (
+                len(hidden_states_list) == parallel_sample_num * num_choices
+            ), f"Expected {parallel_sample_num * num_choices} hidden states, got {len(hidden_states_list)}"
+        else:
+            assert (
+                hidden_states_list == []
+            ), "hidden_states were returned and should not have been"
+
+    def run_chat_completion(self, parallel_sample_num, return_hidden_states):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        response = client.chat.completions.create(
+            model=self.model,
+            messages=[
+                {"role": "system", "content": "You are a helpful AI assistant"},
+                {
+                    "role": "user",
+                    "content": "What is the capital of France? Answer in a few words.",
+                },
+            ],
+            temperature=0,
+            n=parallel_sample_num,
+            extra_body=dict(return_hidden_states=return_hidden_states),
+        )
+
+        for choice in response.choices:
+            assert hasattr(choice, "hidden_states") == return_hidden_states
+            if return_hidden_states:
+                assert choice.hidden_states is not None, "hidden_states was None"
+
+    def run_chat_completion_stream(
+        self, parallel_sample_num=1, return_hidden_states=False
+    ):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        generator = client.chat.completions.create(
+            model=self.model,
+            messages=[
+                {"role": "system", "content": "You are a helpful AI assistant"},
+                {"role": "user", "content": "What is the capital of France?"},
+            ],
+            temperature=0,
+            stream=True,
+            stream_options={"include_usage": True},
+            n=parallel_sample_num,
+            extra_body=dict(return_hidden_states=return_hidden_states),
+        )
+
+        is_firsts = {}
+        hidden_states_list = []
+
+        for response in generator:
+            for choice in response.choices:
+                if hasattr(choice.delta, "hidden_states"):
+                    assert return_hidden_states
+                    assert choice.delta.hidden_states is not None
+                    hidden_states_list.append(choice.delta.hidden_states)
+
+        if return_hidden_states:
+            assert (
+                len(hidden_states_list) == parallel_sample_num
+            ), f"Expected {parallel_sample_num} hidden states, got {len(hidden_states_list)}"
+        else:
+            assert (
+                hidden_states_list == []
+            ), "hidden_states were returned and should not have been"
+
+
+class TestOpenAIServerWithHiddenStatesEnabled(
+    CustomTestCase, BaseTestOpenAIServerWithHiddenStates
+):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            api_key=cls.api_key,
+            other_args=["--enable-return-hidden-states"],
+        )
+        cls.base_url += "/v1"
+        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
+        cls.return_hidden_states = [False, True]
+        cls.use_list_input = [True, False]
+        cls.parallel_sample_nums = [1, 2]
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+
+class TestOpenAIServerWithHiddenStatesEnabledAndCUDAGraphDisabled(
+    CustomTestCase, BaseTestOpenAIServerWithHiddenStates
+):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            api_key=cls.api_key,
+            other_args=["--enable-return-hidden-states", "--disable-cuda-graph"],
+        )
+        cls.base_url += "/v1"
+        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
+        cls.return_hidden_states = [False, True]
+        cls.use_list_input = [True, False]
+        cls.parallel_sample_nums = [1]
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+
+class TestOpenAIServerWithEAGLEAndHiddenStatesEnabled(
+    CustomTestCase, BaseTestOpenAIServerWithHiddenStates
+):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.speculative_draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
+        cls.speculative_algorithm = "EAGLE"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--speculative-algorithm",
+                "EAGLE",
+                "--speculative-draft-model-path",
+                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+                "--speculative-num-steps",
+                5,
+                "--speculative-eagle-topk",
+                8,
+                "--speculative-num-draft-tokens",
+                64,
+                "--mem-fraction-static",
+                0.7,
+                "--chunked-prefill-size",
+                128,
+                "--max-running-requests",
+                8,
+                "--enable-return-hidden-states",
+            ],
+        )
+        cls.base_url += "/v1"
+        cls.tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
+        cls.return_hidden_states = [False, True]
+        cls.use_list_input = [True, False]
+        cls.parallel_sample_nums = [1]
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+
+class TestOpenAIServerWithEAGLE3AndHiddenStatesEnabled(
+    CustomTestCase, BaseTestOpenAIServerWithHiddenStates
+):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "meta-llama/Llama-3.1-8B-Instruct"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.speculative_algorithm = "EAGLE3"
+        cls.speculative_draft_model = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--speculative-algorithm",
+                cls.speculative_algorithm,
+                "--speculative-draft-model-path",
+                cls.speculative_draft_model,
+                "--speculative-num-steps",
+                5,
+                "--speculative-eagle-topk",
+                16,
+                "--speculative-num-draft-tokens",
+                64,
+                "--mem-fraction-static",
+                0.7,
+                "--chunked-prefill-size",
+                128,
+                "--max-running-requests",
+                8,
+                "--dtype",
+                "float16",
+                "--enable-return-hidden-states",
+            ],
+        )
+        cls.base_url += "/v1"
+        cls.tokenizer = get_tokenizer(cls.model)
+        cls.return_hidden_states = [False, True]
+        cls.use_list_input = [True, False]
+        cls.parallel_sample_nums = [1]
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+
+if __name__ == "__main__":
+    unittest.main()