Refine OpenAI serving entrypoint to remove batch requests (#7372)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: Chang Su <csu272@usc.edu>
This commit is contained in:
Xinyuan Tong
2025-06-20 14:33:43 -07:00
committed by GitHub
parent 794be55af2
commit 0998808009
8 changed files with 488 additions and 645 deletions

View File

@@ -20,7 +20,7 @@ import logging
import os
from enum import auto
from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
logger = logging.getLogger(__name__)
completion_template_name = None
@@ -116,7 +116,7 @@ def is_completion_template_defined() -> bool:
return completion_template_name is not None
def generate_completion_prompt_from_request(request: ChatCompletionRequest) -> str:
def generate_completion_prompt_from_request(request: CompletionRequest) -> str:
global completion_template_name
if request.suffix == "":
return request.prompt

View File

@@ -2,7 +2,7 @@ import json
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, Optional, Union
from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse
@@ -37,7 +37,7 @@ class OpenAIServingBase(ABC):
# Convert to internal format
adapted_request, processed_request = self._convert_to_internal_request(
request, self._generate_request_id_base(request)
request
)
# Note(Xinyuan): raw_request below is only used for detecting the connection of the client
@@ -74,10 +74,7 @@ class OpenAIServingBase(ABC):
def _convert_to_internal_request(
self,
request: OpenAIServingRequest,
request_id: str,
) -> tuple[
GenerateReqInput, Union[OpenAIServingRequest, List[OpenAIServingRequest]]
]:
) -> tuple[GenerateReqInput, OpenAIServingRequest]:
"""Convert OpenAI request to internal format"""
pass

View File

@@ -3,7 +3,7 @@ import json
import logging
import time
import uuid
from typing import Any, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from fastapi import Request
from fastapi.responses import StreamingResponse
@@ -52,137 +52,56 @@ class OpenAIServingChat(OpenAIServingBase):
def _request_id_prefix(self) -> str:
    """Return the prefix prepended to generated chat completion request IDs."""
    return "chatcmpl-"
def _validate_request(self, request: ChatCompletionRequest) -> Optional[str]:
"""Validate chat messages format and content"""
if not (messages := request.messages):
return "Messages cannot be empty"
# Check for alternating user/assistant pattern (optional validation)
roles = [msg.role for msg in messages]
# First message should typically be from user or system
if roles[0] not in ["user", "system"]:
return "First message should be from 'user' or 'system'"
# Check for consecutive assistant messages (which might indicate an error)
for i in range(1, len(roles)):
if roles[i] == "assistant" and roles[i - 1] == "assistant":
# This is actually allowed in some cases, so just warn
pass
# Validate message content
for i, msg in enumerate(messages):
if msg.role == "user":
if not msg.content:
return f"User message at index {i} has no content"
elif msg.role == "assistant":
# Assistant messages can have no content if they have tool_calls
if not msg.content and not getattr(msg, "tool_calls", None):
return (
f"Assistant message at index {i} has no content or tool calls"
)
return None
def _convert_to_internal_request(
self,
all_requests: List[ChatCompletionRequest],
request_ids: List[str],
) -> tuple[
GenerateReqInput, Union[ChatCompletionRequest, List[ChatCompletionRequest]]
]:
request: ChatCompletionRequest,
) -> tuple[GenerateReqInput, ChatCompletionRequest]:
"""Convert OpenAI chat completion request to internal format"""
input_ids = []
prompts = []
sampling_params_list = []
image_data_list = []
audio_data_list = []
return_logprobs = []
logprob_start_lens = []
top_logprobs_nums = []
modalities_list = []
lora_paths = []
is_multimodal = self.tokenizer_manager.model_config.is_multimodal
for request in all_requests:
# Process messages and apply chat template
(
prompt,
prompt_ids,
image_data,
audio_data,
modalities,
stop,
tool_call_constraint,
) = self._process_messages(request, is_multimodal)
# Process messages and apply chat template
(
prompt,
prompt_ids,
image_data,
audio_data,
modalities,
stop,
tool_call_constraint,
) = self._process_messages(request, is_multimodal)
input_ids.append(prompt_ids)
prompts.append(prompt)
return_logprobs.append(request.logprobs)
logprob_start_lens.append(-1)
top_logprobs_nums.append(request.top_logprobs or 0)
lora_paths.append(request.lora_path)
# Build sampling parameters
sampling_params = self._build_sampling_params(
request, stop, tool_call_constraint
)
sampling_params_list.append(sampling_params)
image_data_list.append(image_data)
audio_data_list.append(audio_data)
modalities_list.append(modalities)
# Build sampling parameters
sampling_params = self._build_sampling_params(
request, stop, tool_call_constraint
)
# Handle single vs multiple requests
if len(all_requests) == 1:
if is_multimodal:
prompt_kwargs = {"text": prompts[0]}
else:
if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids[0]}
else:
prompt_kwargs = {"input_ids": input_ids[0]}
sampling_params_list = sampling_params_list[0]
image_data_list = image_data_list[0]
audio_data_list = audio_data_list[0]
return_logprobs = return_logprobs[0]
logprob_start_lens = logprob_start_lens[0]
top_logprobs_nums = top_logprobs_nums[0]
modalities_list = modalities_list[0]
lora_paths = lora_paths[0]
request_ids = request_ids[0]
if is_multimodal:
prompt_kwargs = {"text": prompt}
else:
if is_multimodal:
prompt_kwargs = {"text": prompts}
if isinstance(prompt_ids, str):
prompt_kwargs = {"text": prompt_ids}
else:
if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids}
else:
prompt_kwargs = {"input_ids": input_ids}
prompt_kwargs = {"input_ids": prompt_ids}
adapted_request = GenerateReqInput(
**prompt_kwargs,
image_data=image_data_list,
audio_data=audio_data_list,
sampling_params=sampling_params_list,
return_logprob=return_logprobs,
logprob_start_len=logprob_start_lens,
top_logprobs_num=top_logprobs_nums,
stream=all_requests[0].stream,
image_data=image_data,
audio_data=audio_data,
sampling_params=sampling_params,
return_logprob=request.logprobs,
logprob_start_len=-1,
top_logprobs_num=request.top_logprobs or 0,
stream=request.stream,
return_text_in_logprobs=True,
rid=request_ids,
modalities=modalities_list,
lora_path=lora_paths,
bootstrap_host=all_requests[0].bootstrap_host,
bootstrap_port=all_requests[0].bootstrap_port,
bootstrap_room=all_requests[0].bootstrap_room,
modalities=modalities,
lora_path=request.lora_path,
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room,
)
return adapted_request, (
all_requests if len(all_requests) > 1 else all_requests[0]
)
return adapted_request, request
def _process_messages(
self, request: ChatCompletionRequest, is_multimodal: bool
@@ -457,55 +376,138 @@ class OpenAIServingChat(OpenAIServingBase):
raw_request: Request,
) -> StreamingResponse:
"""Handle streaming chat completion request"""
return StreamingResponse(
self._generate_chat_stream(adapted_request, request, raw_request),
media_type="text/event-stream",
background=self.tokenizer_manager.create_abort_task(adapted_request),
)
async def generate_stream_resp():
parser_dict = {}
reasoning_parser_dict = {}
tool_call_first = True
is_firsts = {}
stream_buffers = {}
n_prev_tokens = {}
prompt_tokens = {}
completion_tokens = {}
cached_tokens = {}
async def _generate_chat_stream(
self,
adapted_request: GenerateReqInput,
request: ChatCompletionRequest,
raw_request: Request,
) -> AsyncGenerator[str, None]:
"""Generate streaming chat completion response"""
# Parsers for tool calls and reasoning
parser_dict = {}
reasoning_parser_dict = {}
try:
async for content in self.tokenizer_manager.generate_request(
adapted_request, raw_request
):
index = content.get("index", 0)
# State tracking for streaming
is_firsts = {}
stream_buffers = {}
n_prev_tokens = {}
is_first = is_firsts.get(index, True)
stream_buffer = stream_buffers.get(index, "")
n_prev_token = n_prev_tokens.get(index, 0)
# Usage tracking
prompt_tokens = {}
completion_tokens = {}
cached_tokens = {}
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
completion_tokens[index] = content["meta_info"]["completion_tokens"]
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
try:
async for content in self.tokenizer_manager.generate_request(
adapted_request, raw_request
):
index = content.get("index", 0)
# Handle logprobs
choice_logprobs = None
if request.logprobs:
choice_logprobs = self._process_streaming_logprobs(
content, n_prev_token
)
n_prev_token = len(
content["meta_info"]["output_token_logprobs"]
)
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
completion_tokens[index] = content["meta_info"]["completion_tokens"]
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
finish_reason = content["meta_info"]["finish_reason"]
finish_reason_type = (
finish_reason["type"] if finish_reason else None
# Handle logprobs
choice_logprobs = None
if request.logprobs:
choice_logprobs = self._process_streaming_logprobs(
content, n_prev_tokens.get(index, 0)
)
n_prev_tokens[index] = len(
content["meta_info"]["output_token_logprobs"]
)
# First chunk with role
if is_first:
is_first = False
delta = DeltaMessage(role="assistant")
finish_reason = content["meta_info"]["finish_reason"]
finish_reason_type = finish_reason["type"] if finish_reason else None
# First chunk with role
if is_firsts.get(index, True):
is_firsts[index] = False
delta = DeltaMessage(role="assistant", content="")
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=delta,
finish_reason=finish_reason_type,
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
logprobs=choice_logprobs,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
# Process content delta
stream_buffer = stream_buffers.get(index, "")
delta = content["text"][len(stream_buffer) :]
stream_buffers[index] = stream_buffer + delta
# Handle reasoning content
enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
"enable_thinking", True
)
if (
self.tokenizer_manager.server_args.reasoning_parser
and request.separate_reasoning
and enable_thinking
):
reasoning_text, delta = self._process_reasoning_stream(
index, delta, reasoning_parser_dict, content, request
)
if reasoning_text:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=delta,
delta=DeltaMessage(reasoning_content=reasoning_text),
finish_reason=finish_reason_type,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
if not delta:
continue
# Handle tool calls
if request.tool_choice != "none" and request.tools:
async for chunk in self._process_tool_call_stream(
index,
delta,
parser_dict,
content,
request,
finish_reason_type,
):
yield chunk
else:
# Regular content
if delta or not (
request.stream_options and request.stream_options.include_usage
):
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=delta if delta else None),
finish_reason=(
None
if request.stream_options
and request.stream_options.include_usage
else finish_reason_type
),
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
@@ -521,121 +523,49 @@ class OpenAIServingChat(OpenAIServingBase):
)
yield f"data: {chunk.model_dump_json()}\n\n"
# Process content delta
delta = content["text"][len(stream_buffer) :]
new_stream_buffer = stream_buffer + delta
# Handle reasoning content
enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
"enable_thinking", True
# Final chunk with finish_reason
finish_reason_chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[
ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(),
finish_reason=finish_reason_type,
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
)
if (
self.tokenizer_manager.server_args.reasoning_parser
and request.separate_reasoning
and enable_thinking
):
reasoning_text, delta = self._process_reasoning_stream(
index, delta, reasoning_parser_dict, content, request
)
if reasoning_text:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(reasoning_content=reasoning_text),
finish_reason=finish_reason_type,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
],
model=request.model,
usage=None,
)
yield f"data: {finish_reason_chunk.model_dump_json()}\n\n"
if not delta:
stream_buffers[index] = new_stream_buffer
is_firsts[index] = is_first
n_prev_tokens[index] = n_prev_token
continue
# Handle tool calls
if request.tool_choice != "none" and request.tools:
async for chunk in self._process_tool_call_stream(
index,
delta,
parser_dict,
content,
request,
finish_reason_type,
):
yield chunk
else:
# Regular content
if delta or not (
request.stream_options
and request.stream_options.include_usage
):
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=delta if delta else None),
finish_reason=(
None
if request.stream_options
and request.stream_options.include_usage
else finish_reason_type
),
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
logprobs=choice_logprobs,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
stream_buffers[index] = new_stream_buffer
is_firsts[index] = is_first
n_prev_tokens[index] = n_prev_token
# Final chunk with usage
if request.stream_options and request.stream_options.include_usage:
usage = self._calculate_streaming_usage_base(
prompt_tokens, completion_tokens, cached_tokens, request.n
)
else:
usage = None
final_chunk = ChatCompletionStreamResponse(
# Additional usage chunk
if request.stream_options and request.stream_options.include_usage:
usage = self._calculate_streaming_usage_base(
prompt_tokens,
completion_tokens,
cached_tokens,
request.n,
)
usage_chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[
ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(),
finish_reason=finish_reason_type,
)
],
choices=[], # Empty choices array as per OpenAI spec
model=request.model,
usage=usage,
)
yield f"data: {final_chunk.model_dump_json()}\n\n"
yield f"data: {usage_chunk.model_dump_json()}\n\n"
except Exception as e:
error = self.create_streaming_error_response(str(e))
yield f"data: {error}\n\n"
except Exception as e:
error = self.create_streaming_error_response(str(e))
yield f"data: {error}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
generate_stream_resp(),
media_type="text/event-stream",
background=self.tokenizer_manager.create_abort_task(adapted_request),
)
yield "data: [DONE]\n\n"
async def _handle_non_streaming_request(
self,
@@ -658,9 +588,6 @@ class OpenAIServingChat(OpenAIServingBase):
request,
ret,
int(time.time()),
cache_report=self.tokenizer_manager.server_args.enable_cache_report,
tool_call_parser=self.tokenizer_manager.server_args.tool_call_parser,
reasoning_parser=self.tokenizer_manager.server_args.reasoning_parser,
)
return response
@@ -670,9 +597,6 @@ class OpenAIServingChat(OpenAIServingBase):
request: ChatCompletionRequest,
ret: List[Dict[str, Any]],
created: int,
cache_report: bool = False,
tool_call_parser: Optional[str] = None,
reasoning_parser: Optional[str] = None,
) -> ChatCompletionResponse:
"""Build chat completion response from generation results"""
choices = []
@@ -691,6 +615,7 @@ class OpenAIServingChat(OpenAIServingBase):
enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
"enable_thinking", True
)
reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
if reasoning_parser and request.separate_reasoning and enable_thinking:
try:
parser = ReasoningParser(
@@ -708,6 +633,7 @@ class OpenAIServingChat(OpenAIServingBase):
# Handle tool calls
tool_calls = None
if request.tool_choice != "none" and request.tools:
tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
tool_calls, text, finish_reason = self._process_tool_calls(
text, request.tools, tool_call_parser, finish_reason
)
@@ -731,6 +657,7 @@ class OpenAIServingChat(OpenAIServingBase):
choices.append(choice_data)
# Calculate usage
cache_report = self.tokenizer_manager.server_args.enable_cache_report
usage = aggregate_token_usage(ret, request.n, cache_report)
return ChatCompletionResponse(
@@ -810,7 +737,7 @@ class OpenAIServingChat(OpenAIServingBase):
text, call_info_list = parser.parse_non_stream(text)
tool_calls = [
ToolCall(
id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
id=f"call_{uuid.uuid4().hex[:24]}",
function=FunctionResponse(
name=call_info.name, arguments=call_info.parameters
),
@@ -894,6 +821,16 @@ class OpenAIServingChat(OpenAIServingBase):
# Yield tool calls
for call_item in calls:
# Tool call ID should be generated only once per tool call
if call_item.name:
# First chunk: include ID and function name
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
function_name = call_item.name
else:
# Subsequent chunks: null ID and name for argument deltas
tool_call_id = None
function_name = None
if finish_reason_type == "stop":
# Handle remaining arguments
latest_delta_len = 0
@@ -912,10 +849,10 @@ class OpenAIServingChat(OpenAIServingBase):
finish_reason_type = "tool_calls"
tool_call = ToolCall(
id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
id=tool_call_id,
index=call_item.tool_index,
function=FunctionResponse(
name=call_item.name,
name=function_name,
arguments=call_item.parameters,
),
)

View File

@@ -1,5 +1,6 @@
import logging
import time
from typing import Any, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, Dict, List, Union
from fastapi import Request
from fastapi.responses import StreamingResponse
@@ -23,6 +24,8 @@ from sglang.srt.entrypoints.openai.utils import (
)
from sglang.srt.managers.io_struct import GenerateReqInput
logger = logging.getLogger(__name__)
class OpenAIServingCompletion(OpenAIServingBase):
"""Handler for completion requests"""
@@ -30,134 +33,54 @@ class OpenAIServingCompletion(OpenAIServingBase):
def _request_id_prefix(self) -> str:
    """Return the prefix prepended to generated completion request IDs."""
    return "cmpl-"
def _validate_request(self, request: CompletionRequest) -> Optional[str]:
"""Validate completion prompt format and content"""
if not (prompt := request.prompt):
return "Prompt cannot be None"
if isinstance(prompt, str):
if not prompt.strip():
return "Prompt cannot be empty or whitespace only"
elif isinstance(prompt, list):
if not prompt:
return "Prompt list cannot be empty"
# Check if it's a list of strings
if all(isinstance(item, str) for item in prompt):
for i, item in enumerate(prompt):
if not item.strip():
return f"Prompt at index {i} cannot be empty or whitespace only"
# Check if it's a list of token IDs (integers)
elif all(isinstance(item, int) for item in prompt):
if any(item < 0 for item in prompt):
return "Token IDs must be non-negative"
# Check if it's a list of lists (multiple token sequences)
elif all(isinstance(item, list) for item in prompt):
for i, item in enumerate(prompt):
if not item:
return f"Token sequence at index {i} cannot be empty"
if not all(isinstance(token, int) for token in item):
return f"Token sequence at index {i} must contain only integers"
if any(token < 0 for token in item):
return (
f"Token sequence at index {i} contains negative token IDs"
)
else:
return "Prompt must be string, list of strings, list of integers, or list of integer lists"
else:
return "Prompt must be string or list"
return None
def _convert_to_internal_request(
self,
all_requests: List[CompletionRequest],
request_ids: List[str],
) -> tuple[GenerateReqInput, Union[CompletionRequest, List[CompletionRequest]]]:
request: CompletionRequest,
) -> tuple[GenerateReqInput, CompletionRequest]:
"""Convert OpenAI completion request to internal format"""
# Validate batch requests
if len(all_requests) > 1:
first_prompt_type = type(all_requests[0].prompt)
for request in all_requests:
assert (
type(request.prompt) is first_prompt_type
), "All prompts must be of the same type in file input settings"
if request.n > 1:
raise ValueError(
"Parallel sampling is not supported for completions from files"
)
prompts = []
sampling_params_list = []
return_logprobs = []
logprob_start_lens = []
top_logprobs_nums = []
lora_paths = []
for request in all_requests:
# Process prompt
prompt = request.prompt
if is_completion_template_defined():
prompt = generate_completion_prompt_from_request(request)
prompts.append(prompt)
lora_paths.append(request.lora_path)
# Set logprob start length based on echo and logprobs
if request.echo and request.logprobs:
current_logprob_start_len = 0
else:
current_logprob_start_len = -1
# Build sampling parameters
sampling_params = self._build_sampling_params(request)
sampling_params_list.append(sampling_params)
return_logprobs.append(request.logprobs is not None)
logprob_start_lens.append(current_logprob_start_len)
top_logprobs_nums.append(
request.logprobs if request.logprobs is not None else 0
# NOTE: with openai API, the prompt's logprobs are always not computed
if request.echo and request.logprobs:
logger.warning(
"Echo is not compatible with logprobs. "
"To compute logprobs of input prompt, please use the native /generate API."
)
# Process prompt
prompt = request.prompt
if is_completion_template_defined():
prompt = generate_completion_prompt_from_request(request)
# Handle single vs multiple requests
if len(all_requests) == 1:
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
prompt_kwargs = {"text": prompts[0]}
else:
prompt_kwargs = {"input_ids": prompts[0]}
sampling_params_list = sampling_params_list[0]
return_logprobs = return_logprobs[0]
logprob_start_lens = logprob_start_lens[0]
top_logprobs_nums = top_logprobs_nums[0]
lora_paths = lora_paths[0]
request_ids = request_ids[0]
# Set logprob start length based on echo and logprobs
if request.echo and request.logprobs:
logprob_start_len = 0
else:
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
prompt_kwargs = {"text": prompts}
else:
prompt_kwargs = {"input_ids": prompts}
logprob_start_len = -1
# Build sampling parameters
sampling_params = self._build_sampling_params(request)
# Determine prompt format
if isinstance(prompt, str) or (
isinstance(prompt, list) and isinstance(prompt[0], str)
):
prompt_kwargs = {"text": prompt}
else:
prompt_kwargs = {"input_ids": prompt}
adapted_request = GenerateReqInput(
**prompt_kwargs,
sampling_params=sampling_params_list,
return_logprob=return_logprobs,
top_logprobs_num=top_logprobs_nums,
logprob_start_len=logprob_start_lens,
sampling_params=sampling_params,
return_logprob=request.logprobs is not None,
top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
logprob_start_len=logprob_start_len,
return_text_in_logprobs=True,
stream=all_requests[0].stream,
rid=request_ids,
lora_path=lora_paths,
bootstrap_host=all_requests[0].bootstrap_host,
bootstrap_port=all_requests[0].bootstrap_port,
bootstrap_room=all_requests[0].bootstrap_room,
stream=request.stream,
lora_path=request.lora_path,
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room,
)
return adapted_request, (
all_requests if len(all_requests) > 1 else all_requests[0]
)
return adapted_request, request
def _build_sampling_params(self, request: CompletionRequest) -> Dict[str, Any]:
"""Build sampling parameters for the request"""
@@ -184,9 +107,6 @@ class OpenAIServingCompletion(OpenAIServingBase):
"logit_bias": request.logit_bias,
}
# No additional completion-specific parameters needed currently
# (json_schema is already handled in base method)
return sampling_params
async def _handle_streaming_request(
@@ -196,122 +116,126 @@ class OpenAIServingCompletion(OpenAIServingBase):
raw_request: Request,
) -> StreamingResponse:
"""Handle streaming completion request"""
created = int(time.time())
async def generate_stream_resp():
stream_buffers = {}
n_prev_tokens = {}
prompt_tokens = {}
completion_tokens = {}
cached_tokens = {}
try:
async for content in self.tokenizer_manager.generate_request(
adapted_request, raw_request
):
index = content.get("index", 0)
stream_buffer = stream_buffers.get(index, "")
n_prev_token = n_prev_tokens.get(index, 0)
text = content["text"]
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
completion_tokens[index] = content["meta_info"]["completion_tokens"]
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
# Handle echo for first chunk
if not stream_buffer: # The first chunk
if request.echo:
echo_text = self._get_echo_text(request, index)
text = echo_text + text
# Handle logprobs
logprobs = None
if request.logprobs is not None:
# The first chunk and echo is enabled.
if not stream_buffer and request.echo:
input_token_logprobs = content["meta_info"][
"input_token_logprobs"
]
input_top_logprobs = content["meta_info"][
"input_top_logprobs"
]
else:
input_token_logprobs = None
input_top_logprobs = None
logprobs = to_openai_style_logprobs(
input_token_logprobs=input_token_logprobs,
input_top_logprobs=input_top_logprobs,
output_token_logprobs=content["meta_info"][
"output_token_logprobs"
][n_prev_token:],
output_top_logprobs=content["meta_info"][
"output_top_logprobs"
][n_prev_token:],
)
n_prev_token = len(
content["meta_info"]["output_token_logprobs"]
)
# Generate delta
delta = text[len(stream_buffer) :]
stream_buffer = stream_buffer + delta
finish_reason = content["meta_info"]["finish_reason"]
choice_data = CompletionResponseStreamChoice(
index=index,
text=delta,
logprobs=logprobs,
finish_reason=finish_reason["type"] if finish_reason else None,
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
)
chunk = CompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
object="text_completion",
choices=[choice_data],
model=request.model,
)
stream_buffers[index] = stream_buffer
n_prev_tokens[index] = n_prev_token
yield f"data: {chunk.model_dump_json()}\n\n"
# Handle final usage chunk
if request.stream_options and request.stream_options.include_usage:
usage = self._calculate_streaming_usage_base(
prompt_tokens, completion_tokens, cached_tokens, request.n
)
final_usage_chunk = CompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
choices=[],
model=request.model,
usage=usage,
)
final_usage_data = final_usage_chunk.model_dump_json(
exclude_none=True
)
yield f"data: {final_usage_data}\n\n"
except Exception as e:
error = self.create_streaming_error_response(str(e))
yield f"data: {error}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
generate_stream_resp(),
self._generate_completion_stream(adapted_request, request, raw_request),
media_type="text/event-stream",
background=self.tokenizer_manager.create_abort_task(adapted_request),
)
async def _generate_completion_stream(
    self,
    adapted_request: GenerateReqInput,
    request: CompletionRequest,
    raw_request: Request,
) -> AsyncGenerator[str, None]:
    """Generate streaming completion response.

    Yields SSE-formatted ``data: ...`` strings: one chunk per text delta
    from the generator, an optional final usage chunk when the client
    requested ``stream_options.include_usage``, and a closing
    ``data: [DONE]`` sentinel. Any exception is converted into a single
    SSE error payload rather than propagated.
    """
    created = int(time.time())

    # State tracking for streaming, keyed by choice index
    # (multiple indices appear in the stream — presumably from
    # parallel sampling with n > 1; confirm against generate_request).
    stream_buffers = {}
    n_prev_tokens = {}

    # Usage tracking, keyed by choice index
    prompt_tokens = {}
    completion_tokens = {}
    cached_tokens = {}

    try:
        async for content in self.tokenizer_manager.generate_request(
            adapted_request, raw_request
        ):
            index = content.get("index", 0)

            text = content["text"]
            # Overwritten every chunk; the last value per index is what
            # the usage calculation below sees.
            prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
            completion_tokens[index] = content["meta_info"]["completion_tokens"]
            cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)

            stream_buffer = stream_buffers.get(index, "")
            # Handle echo for first chunk
            if not stream_buffer:  # The first chunk
                if request.echo:
                    # Prepend the prompt text so the echoed prompt is
                    # part of the first delta.
                    echo_text = self._get_echo_text(request, index)
                    text = echo_text + text

            # Handle logprobs
            logprobs = None
            if request.logprobs is not None:
                # The first chunk and echo is enabled.
                if not stream_buffer and request.echo:
                    input_token_logprobs = content["meta_info"][
                        "input_token_logprobs"
                    ]
                    input_top_logprobs = content["meta_info"]["input_top_logprobs"]
                else:
                    input_token_logprobs = None
                    input_top_logprobs = None

                # Slice from n_prev_token so each chunk carries only the
                # logprobs of tokens not yet sent for this index.
                n_prev_token = n_prev_tokens.get(index, 0)
                logprobs = to_openai_style_logprobs(
                    input_token_logprobs=input_token_logprobs,
                    input_top_logprobs=input_top_logprobs,
                    output_token_logprobs=content["meta_info"][
                        "output_token_logprobs"
                    ][n_prev_token:],
                    output_top_logprobs=content["meta_info"]["output_top_logprobs"][
                        n_prev_token:
                    ],
                )
                n_prev_tokens[index] = len(
                    content["meta_info"]["output_token_logprobs"]
                )

            # Generate delta: content["text"] is cumulative, so send only
            # the suffix beyond what was already streamed for this index.
            delta = text[len(stream_buffer) :]
            stream_buffers[index] = stream_buffer + delta
            finish_reason = content["meta_info"]["finish_reason"]
            choice_data = CompletionResponseStreamChoice(
                index=index,
                text=delta,
                logprobs=logprobs,
                finish_reason=finish_reason["type"] if finish_reason else None,
                matched_stop=(
                    finish_reason["matched"]
                    if finish_reason and "matched" in finish_reason
                    else None
                ),
            )
            chunk = CompletionStreamResponse(
                id=content["meta_info"]["id"],
                created=created,
                object="text_completion",
                choices=[choice_data],
                model=request.model,
            )
            yield f"data: {chunk.model_dump_json()}\n\n"

        # Handle final usage chunk
        if request.stream_options and request.stream_options.include_usage:
            usage = self._calculate_streaming_usage_base(
                prompt_tokens,
                completion_tokens,
                cached_tokens,
                request.n,
            )
            # NOTE(review): reuses `content` from the loop above — if the
            # generator yields nothing, this raises NameError (caught by
            # the except below and reported as a stream error). Confirm
            # generate_request always yields at least once.
            final_usage_chunk = CompletionStreamResponse(
                id=content["meta_info"]["id"],
                created=created,
                choices=[],
                model=request.model,
                usage=usage,
            )
            final_usage_data = final_usage_chunk.model_dump_json(exclude_none=True)
            yield f"data: {final_usage_data}\n\n"
    except Exception as e:
        # Surface errors to the client as an SSE payload instead of
        # breaking the event stream.
        error = self.create_streaming_error_response(str(e))
        yield f"data: {error}\n\n"

    yield "data: [DONE]\n\n"
async def _handle_non_streaming_request(
self,
adapted_request: GenerateReqInput,
@@ -334,7 +258,6 @@ class OpenAIServingCompletion(OpenAIServingBase):
request,
ret,
int(time.time()),
cache_report=self.tokenizer_manager.server_args.enable_cache_report,
)
return response
@@ -344,7 +267,6 @@ class OpenAIServingCompletion(OpenAIServingBase):
request: CompletionRequest,
ret: List[Dict[str, Any]],
created: int,
cache_report: bool = False,
) -> CompletionResponse:
"""Build completion response from generation results"""
choices = []
@@ -352,7 +274,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
# Prepare echo prompts if needed
echo_prompts = []
if (not isinstance(request, list)) and request.echo:
if request.echo:
echo_prompts = self._prepare_echo_prompts(request)
echo = True
@@ -360,21 +282,13 @@ class OpenAIServingCompletion(OpenAIServingBase):
text = ret_item["text"]
# Handle echo
if isinstance(request, list) and request[idx].echo:
echo = True
text = request[idx].prompt + text
elif echo and not isinstance(request, list):
if echo:
prompt_index = idx // request.n
text = echo_prompts[prompt_index] + text
# Handle logprobs
logprobs = None
if isinstance(request, list) and request[idx].logprobs is not None:
logprobs = True
elif (not isinstance(request, list)) and request.logprobs is not None:
logprobs = True
if logprobs:
if request.logprobs is not None:
if echo:
input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
@@ -407,6 +321,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
choices.append(choice_data)
# Calculate usage
cache_report = self.tokenizer_manager.server_args.enable_cache_report
usage = aggregate_token_usage(ret, request.n, cache_report)
return CompletionResponse(

View File

@@ -54,35 +54,25 @@ class OpenAIServingEmbedding(OpenAIServingBase):
return f"All items in input list must be integers"
if item < 0:
return f"Token ID at index {i} must be non-negative"
elif isinstance(first_item, list):
# List of lists (multiple token sequences)
for i, item in enumerate(input):
if not isinstance(item, list):
return f"Input at index {i} must be a list"
if not item:
return f"Input at index {i} cannot be empty"
if not all(isinstance(token, int) for token in item):
return f"Input at index {i} must contain only integers"
if any(token < 0 for token in item):
return f"Input at index {i} contains negative token IDs"
# Note: MultimodalEmbeddingInput validation would be handled by Pydantic
return None
def _convert_to_internal_request(
self,
request: EmbeddingRequest,
request_id: str,
) -> tuple[EmbeddingReqInput, Union[EmbeddingRequest, List[EmbeddingRequest]]]:
) -> tuple[EmbeddingReqInput, EmbeddingRequest]:
"""Convert OpenAI embedding request to internal format"""
prompt = request.input
if isinstance(prompt, str):
# Single string input
prompt_kwargs = {"text": prompt}
elif isinstance(prompt, list):
if len(prompt) > 0 and isinstance(prompt[0], str):
# List of strings
prompt_kwargs = {"text": prompt}
# List of strings - if it's a single string in a list, treat as single string
if len(prompt) == 1:
prompt_kwargs = {"text": prompt[0]}
else:
prompt_kwargs = {"text": prompt}
elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
# Handle multimodal embedding inputs
texts = []
@@ -94,7 +84,6 @@ class OpenAIServingEmbedding(OpenAIServingBase):
generate_prompts = []
# Check if we have a chat template for multimodal embeddings
# This would need to be passed in from the server configuration
chat_template_name = getattr(
self.tokenizer_manager, "chat_template_name", None
)
@@ -121,6 +110,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
else:
# Other types (should not happen but handle gracefully)
prompt_kwargs = {"input_ids": prompt}
adapted_request = EmbeddingReqInput(
**prompt_kwargs,
)