Refine OpenAI serving entrypoint to remove batch requests (#7372)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: Chang Su <csu272@usc.edu>
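
In essence, the handlers now convert exactly one OpenAI request into one internal request, instead of a list of requests plus per-request IDs. A minimal, self-contained sketch of that contract (illustration only, not the SGLang source; GenerateReqInput below is a stand-in dataclass, not the real io_struct class):

from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class GenerateReqInput:  # stand-in for sglang.srt.managers.io_struct.GenerateReqInput
    text: str = ""
    sampling_params: Dict[str, Any] = field(default_factory=dict)
    stream: bool = False

def convert_to_internal_request(request: Any) -> tuple[GenerateReqInput, Any]:
    # New single-request contract: one OpenAI request in, one internal request out.
    adapted = GenerateReqInput(
        text=getattr(request, "prompt", ""),
        sampling_params={"temperature": getattr(request, "temperature", 1.0)},
        stream=getattr(request, "stream", False),
    )
    return adapted, request
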
@@ -20,7 +20,7 @@ import logging
import os
from enum import auto

-from sglang.srt.openai_api.protocol import ChatCompletionRequest
+from sglang.srt.entrypoints.openai.protocol import CompletionRequest

logger = logging.getLogger(__name__)
completion_template_name = None
@@ -116,7 +116,7 @@ def is_completion_template_defined() -> bool:
    return completion_template_name is not None


-def generate_completion_prompt_from_request(request: ChatCompletionRequest) -> str:
+def generate_completion_prompt_from_request(request: CompletionRequest) -> str:
    global completion_template_name
    if request.suffix == "":
        return request.prompt
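
For context, a hedged sketch of how a suffix-aware completion template might assemble a fill-in-the-middle prompt; the marker strings here are made up for illustration and are not SGLang's actual template format:

def build_fim_prompt(prompt: str, suffix: str) -> str:
    # Hypothetical fill-in-the-middle layout: with no suffix the prompt is
    # returned unchanged, mirroring the early return above.
    if suffix == "":
        return prompt
    return f"<prefix>{prompt}<suffix>{suffix}<middle>"

print(build_fim_prompt("def add(a, b):\n    return ", ""))
print(build_fim_prompt("def add(a, b):\n    return ", "\n\nprint(add(1, 2))"))
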
@@ -2,7 +2,7 @@ import json
import logging
import uuid
from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, Optional, Union

from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse
@@ -37,7 +37,7 @@ class OpenAIServingBase(ABC):

# Convert to internal format
adapted_request, processed_request = self._convert_to_internal_request(
-    request, self._generate_request_id_base(request)
+    request
)

# Note(Xinyuan): raw_request below is only used for detecting the connection of the client
@@ -74,10 +74,7 @@ class OpenAIServingBase(ABC):
def _convert_to_internal_request(
    self,
    request: OpenAIServingRequest,
-    request_id: str,
-) -> tuple[
-    GenerateReqInput, Union[OpenAIServingRequest, List[OpenAIServingRequest]]
-]:
+) -> tuple[GenerateReqInput, OpenAIServingRequest]:
    """Convert OpenAI request to internal format"""
    pass

@@ -3,7 +3,7 @@ import json
import logging
import time
import uuid
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union

from fastapi import Request
from fastapi.responses import StreamingResponse
@@ -52,137 +52,56 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
def _request_id_prefix(self) -> str:
|
||||
return "chatcmpl-"
|
||||
|
||||
def _validate_request(self, request: ChatCompletionRequest) -> Optional[str]:
|
||||
"""Validate chat messages format and content"""
|
||||
if not (messages := request.messages):
|
||||
return "Messages cannot be empty"
|
||||
|
||||
# Check for alternating user/assistant pattern (optional validation)
|
||||
roles = [msg.role for msg in messages]
|
||||
|
||||
# First message should typically be from user or system
|
||||
if roles[0] not in ["user", "system"]:
|
||||
return "First message should be from 'user' or 'system'"
|
||||
|
||||
# Check for consecutive assistant messages (which might indicate an error)
|
||||
for i in range(1, len(roles)):
|
||||
if roles[i] == "assistant" and roles[i - 1] == "assistant":
|
||||
# This is actually allowed in some cases, so just warn
|
||||
pass
|
||||
|
||||
# Validate message content
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role == "user":
|
||||
if not msg.content:
|
||||
return f"User message at index {i} has no content"
|
||||
elif msg.role == "assistant":
|
||||
# Assistant messages can have no content if they have tool_calls
|
||||
if not msg.content and not getattr(msg, "tool_calls", None):
|
||||
return (
|
||||
f"Assistant message at index {i} has no content or tool calls"
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _convert_to_internal_request(
|
||||
self,
|
||||
all_requests: List[ChatCompletionRequest],
|
||||
request_ids: List[str],
|
||||
) -> tuple[
|
||||
GenerateReqInput, Union[ChatCompletionRequest, List[ChatCompletionRequest]]
|
||||
]:
|
||||
request: ChatCompletionRequest,
|
||||
) -> tuple[GenerateReqInput, ChatCompletionRequest]:
|
||||
"""Convert OpenAI chat completion request to internal format"""
|
||||
input_ids = []
|
||||
prompts = []
|
||||
sampling_params_list = []
|
||||
image_data_list = []
|
||||
audio_data_list = []
|
||||
return_logprobs = []
|
||||
logprob_start_lens = []
|
||||
top_logprobs_nums = []
|
||||
modalities_list = []
|
||||
lora_paths = []
|
||||
|
||||
is_multimodal = self.tokenizer_manager.model_config.is_multimodal
|
||||
|
||||
for request in all_requests:
|
||||
# Process messages and apply chat template
|
||||
(
|
||||
prompt,
|
||||
prompt_ids,
|
||||
image_data,
|
||||
audio_data,
|
||||
modalities,
|
||||
stop,
|
||||
tool_call_constraint,
|
||||
) = self._process_messages(request, is_multimodal)
|
||||
# Process messages and apply chat template
|
||||
(
|
||||
prompt,
|
||||
prompt_ids,
|
||||
image_data,
|
||||
audio_data,
|
||||
modalities,
|
||||
stop,
|
||||
tool_call_constraint,
|
||||
) = self._process_messages(request, is_multimodal)
|
||||
|
||||
input_ids.append(prompt_ids)
|
||||
prompts.append(prompt)
|
||||
return_logprobs.append(request.logprobs)
|
||||
logprob_start_lens.append(-1)
|
||||
top_logprobs_nums.append(request.top_logprobs or 0)
|
||||
lora_paths.append(request.lora_path)
|
||||
|
||||
# Build sampling parameters
|
||||
sampling_params = self._build_sampling_params(
|
||||
request, stop, tool_call_constraint
|
||||
)
|
||||
sampling_params_list.append(sampling_params)
|
||||
|
||||
image_data_list.append(image_data)
|
||||
audio_data_list.append(audio_data)
|
||||
modalities_list.append(modalities)
|
||||
# Build sampling parameters
|
||||
sampling_params = self._build_sampling_params(
|
||||
request, stop, tool_call_constraint
|
||||
)
|
||||
|
||||
# Handle single vs multiple requests
|
||||
if len(all_requests) == 1:
|
||||
if is_multimodal:
|
||||
prompt_kwargs = {"text": prompts[0]}
|
||||
else:
|
||||
if isinstance(input_ids[0], str):
|
||||
prompt_kwargs = {"text": input_ids[0]}
|
||||
else:
|
||||
prompt_kwargs = {"input_ids": input_ids[0]}
|
||||
|
||||
sampling_params_list = sampling_params_list[0]
|
||||
image_data_list = image_data_list[0]
|
||||
audio_data_list = audio_data_list[0]
|
||||
return_logprobs = return_logprobs[0]
|
||||
logprob_start_lens = logprob_start_lens[0]
|
||||
top_logprobs_nums = top_logprobs_nums[0]
|
||||
modalities_list = modalities_list[0]
|
||||
lora_paths = lora_paths[0]
|
||||
request_ids = request_ids[0]
|
||||
if is_multimodal:
|
||||
prompt_kwargs = {"text": prompt}
|
||||
else:
|
||||
if is_multimodal:
|
||||
prompt_kwargs = {"text": prompts}
|
||||
if isinstance(prompt_ids, str):
|
||||
prompt_kwargs = {"text": prompt_ids}
|
||||
else:
|
||||
if isinstance(input_ids[0], str):
|
||||
prompt_kwargs = {"text": input_ids}
|
||||
else:
|
||||
prompt_kwargs = {"input_ids": input_ids}
|
||||
prompt_kwargs = {"input_ids": prompt_ids}
|
||||
|
||||
adapted_request = GenerateReqInput(
|
||||
**prompt_kwargs,
|
||||
image_data=image_data_list,
|
||||
audio_data=audio_data_list,
|
||||
sampling_params=sampling_params_list,
|
||||
return_logprob=return_logprobs,
|
||||
logprob_start_len=logprob_start_lens,
|
||||
top_logprobs_num=top_logprobs_nums,
|
||||
stream=all_requests[0].stream,
|
||||
image_data=image_data,
|
||||
audio_data=audio_data,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=request.logprobs,
|
||||
logprob_start_len=-1,
|
||||
top_logprobs_num=request.top_logprobs or 0,
|
||||
stream=request.stream,
|
||||
return_text_in_logprobs=True,
|
||||
rid=request_ids,
|
||||
modalities=modalities_list,
|
||||
lora_path=lora_paths,
|
||||
bootstrap_host=all_requests[0].bootstrap_host,
|
||||
bootstrap_port=all_requests[0].bootstrap_port,
|
||||
bootstrap_room=all_requests[0].bootstrap_room,
|
||||
modalities=modalities,
|
||||
lora_path=request.lora_path,
|
||||
bootstrap_host=request.bootstrap_host,
|
||||
bootstrap_port=request.bootstrap_port,
|
||||
bootstrap_room=request.bootstrap_room,
|
||||
)
|
||||
|
||||
return adapted_request, (
|
||||
all_requests if len(all_requests) > 1 else all_requests[0]
|
||||
)
|
||||
return adapted_request, request
|
||||
|
||||
def _process_messages(
|
||||
self, request: ChatCompletionRequest, is_multimodal: bool
|
||||
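
As a rough sketch (simplified, with hypothetical helper names, not the handler itself) of what the single-request chat conversion above boils down to: multimodal prompts stay as rendered text, tokenized prompts go through input_ids, and a plain-string template output is passed as text.

from typing import Any, Dict, List, Union

def build_prompt_kwargs(
    prompt: str,
    prompt_ids: Union[str, List[int]],
    is_multimodal: bool,
) -> Dict[str, Any]:
    # Multimodal models receive the rendered chat text; text-only models
    # receive token IDs unless the template already produced a plain string.
    if is_multimodal:
        return {"text": prompt}
    if isinstance(prompt_ids, str):
        return {"text": prompt_ids}
    return {"input_ids": prompt_ids}

print(build_prompt_kwargs("<chat>hi</chat>", [1, 2, 3], is_multimodal=False))
print(build_prompt_kwargs("<chat>hi</chat>", [1, 2, 3], is_multimodal=True))
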
@@ -457,55 +376,138 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
raw_request: Request,
|
||||
) -> StreamingResponse:
|
||||
"""Handle streaming chat completion request"""
|
||||
return StreamingResponse(
|
||||
self._generate_chat_stream(adapted_request, request, raw_request),
|
||||
media_type="text/event-stream",
|
||||
background=self.tokenizer_manager.create_abort_task(adapted_request),
|
||||
)
|
||||
|
||||
async def generate_stream_resp():
|
||||
parser_dict = {}
|
||||
reasoning_parser_dict = {}
|
||||
tool_call_first = True
|
||||
is_firsts = {}
|
||||
stream_buffers = {}
|
||||
n_prev_tokens = {}
|
||||
prompt_tokens = {}
|
||||
completion_tokens = {}
|
||||
cached_tokens = {}
|
||||
async def _generate_chat_stream(
|
||||
self,
|
||||
adapted_request: GenerateReqInput,
|
||||
request: ChatCompletionRequest,
|
||||
raw_request: Request,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate streaming chat completion response"""
|
||||
# Parsers for tool calls and reasoning
|
||||
parser_dict = {}
|
||||
reasoning_parser_dict = {}
|
||||
|
||||
try:
|
||||
async for content in self.tokenizer_manager.generate_request(
|
||||
adapted_request, raw_request
|
||||
):
|
||||
index = content.get("index", 0)
|
||||
# State tracking for streaming
|
||||
is_firsts = {}
|
||||
stream_buffers = {}
|
||||
n_prev_tokens = {}
|
||||
|
||||
is_first = is_firsts.get(index, True)
|
||||
stream_buffer = stream_buffers.get(index, "")
|
||||
n_prev_token = n_prev_tokens.get(index, 0)
|
||||
# Usage tracking
|
||||
prompt_tokens = {}
|
||||
completion_tokens = {}
|
||||
cached_tokens = {}
|
||||
|
||||
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
|
||||
completion_tokens[index] = content["meta_info"]["completion_tokens"]
|
||||
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
|
||||
try:
|
||||
async for content in self.tokenizer_manager.generate_request(
|
||||
adapted_request, raw_request
|
||||
):
|
||||
index = content.get("index", 0)
|
||||
|
||||
# Handle logprobs
|
||||
choice_logprobs = None
|
||||
if request.logprobs:
|
||||
choice_logprobs = self._process_streaming_logprobs(
|
||||
content, n_prev_token
|
||||
)
|
||||
n_prev_token = len(
|
||||
content["meta_info"]["output_token_logprobs"]
|
||||
)
|
||||
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
|
||||
completion_tokens[index] = content["meta_info"]["completion_tokens"]
|
||||
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
|
||||
|
||||
finish_reason = content["meta_info"]["finish_reason"]
|
||||
finish_reason_type = (
|
||||
finish_reason["type"] if finish_reason else None
|
||||
# Handle logprobs
|
||||
choice_logprobs = None
|
||||
if request.logprobs:
|
||||
choice_logprobs = self._process_streaming_logprobs(
|
||||
content, n_prev_tokens.get(index, 0)
|
||||
)
|
||||
n_prev_tokens[index] = len(
|
||||
content["meta_info"]["output_token_logprobs"]
|
||||
)
|
||||
|
||||
# First chunk with role
|
||||
if is_first:
|
||||
is_first = False
|
||||
delta = DeltaMessage(role="assistant")
|
||||
finish_reason = content["meta_info"]["finish_reason"]
|
||||
finish_reason_type = finish_reason["type"] if finish_reason else None
|
||||
|
||||
# First chunk with role
|
||||
if is_firsts.get(index, True):
|
||||
is_firsts[index] = False
|
||||
delta = DeltaMessage(role="assistant", content="")
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=delta,
|
||||
finish_reason=finish_reason_type,
|
||||
matched_stop=(
|
||||
finish_reason["matched"]
|
||||
if finish_reason and "matched" in finish_reason
|
||||
else None
|
||||
),
|
||||
logprobs=choice_logprobs,
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
created=int(time.time()),
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
# Process content delta
|
||||
stream_buffer = stream_buffers.get(index, "")
|
||||
delta = content["text"][len(stream_buffer) :]
|
||||
stream_buffers[index] = stream_buffer + delta
|
||||
|
||||
# Handle reasoning content
|
||||
enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
|
||||
"enable_thinking", True
|
||||
)
|
||||
if (
|
||||
self.tokenizer_manager.server_args.reasoning_parser
|
||||
and request.separate_reasoning
|
||||
and enable_thinking
|
||||
):
|
||||
reasoning_text, delta = self._process_reasoning_stream(
|
||||
index, delta, reasoning_parser_dict, content, request
|
||||
)
|
||||
if reasoning_text:
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=delta,
|
||||
delta=DeltaMessage(reasoning_content=reasoning_text),
|
||||
finish_reason=finish_reason_type,
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
created=int(time.time()),
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
if not delta:
|
||||
continue
|
||||
|
||||
# Handle tool calls
|
||||
if request.tool_choice != "none" and request.tools:
|
||||
async for chunk in self._process_tool_call_stream(
|
||||
index,
|
||||
delta,
|
||||
parser_dict,
|
||||
content,
|
||||
request,
|
||||
finish_reason_type,
|
||||
):
|
||||
yield chunk
|
||||
else:
|
||||
# Regular content
|
||||
if delta or not (
|
||||
request.stream_options and request.stream_options.include_usage
|
||||
):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(content=delta if delta else None),
|
||||
finish_reason=(
|
||||
None
|
||||
if request.stream_options
|
||||
and request.stream_options.include_usage
|
||||
else finish_reason_type
|
||||
),
|
||||
matched_stop=(
|
||||
finish_reason["matched"]
|
||||
if finish_reason and "matched" in finish_reason
|
||||
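
The streaming generator above keeps its per-choice state (is_firsts, stream_buffers, n_prev_tokens) in dicts keyed by the chunk's choice index, because with n > 1 the chunks for different choices arrive interleaved and each carries the cumulative text so far. A minimal, self-contained sketch of that bookkeeping pattern (not the serving code itself):

from typing import Dict, Iterable

def diff_stream(chunks: Iterable[dict]) -> Iterable[str]:
    # Track what has already been emitted per choice index, so each event
    # only carries the newly generated text for that choice.
    stream_buffers: Dict[int, str] = {}
    for chunk in chunks:
        index = chunk.get("index", 0)
        buffered = stream_buffers.get(index, "")
        delta = chunk["text"][len(buffered):]
        stream_buffers[index] = buffered + delta
        if delta:
            yield f"choice {index}: {delta!r}"

events = [
    {"index": 0, "text": "Hel"},
    {"index": 1, "text": "Bon"},
    {"index": 0, "text": "Hello"},
    {"index": 1, "text": "Bonjour"},
]
print(list(diff_stream(events)))
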
@@ -521,121 +523,49 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
# Process content delta
|
||||
delta = content["text"][len(stream_buffer) :]
|
||||
new_stream_buffer = stream_buffer + delta
|
||||
|
||||
# Handle reasoning content
|
||||
enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
|
||||
"enable_thinking", True
|
||||
# Final chunk with finish_reason
|
||||
finish_reason_chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
created=int(time.time()),
|
||||
choices=[
|
||||
ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(),
|
||||
finish_reason=finish_reason_type,
|
||||
matched_stop=(
|
||||
finish_reason["matched"]
|
||||
if finish_reason and "matched" in finish_reason
|
||||
else None
|
||||
),
|
||||
)
|
||||
if (
|
||||
self.tokenizer_manager.server_args.reasoning_parser
|
||||
and request.separate_reasoning
|
||||
and enable_thinking
|
||||
):
|
||||
reasoning_text, delta = self._process_reasoning_stream(
|
||||
index, delta, reasoning_parser_dict, content, request
|
||||
)
|
||||
if reasoning_text:
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(reasoning_content=reasoning_text),
|
||||
finish_reason=finish_reason_type,
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
created=int(time.time()),
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
],
|
||||
model=request.model,
|
||||
usage=None,
|
||||
)
|
||||
yield f"data: {finish_reason_chunk.model_dump_json()}\n\n"
|
||||
|
||||
if not delta:
|
||||
stream_buffers[index] = new_stream_buffer
|
||||
is_firsts[index] = is_first
|
||||
n_prev_tokens[index] = n_prev_token
|
||||
continue
|
||||
|
||||
# Handle tool calls
|
||||
if request.tool_choice != "none" and request.tools:
|
||||
async for chunk in self._process_tool_call_stream(
|
||||
index,
|
||||
delta,
|
||||
parser_dict,
|
||||
content,
|
||||
request,
|
||||
finish_reason_type,
|
||||
):
|
||||
yield chunk
|
||||
else:
|
||||
# Regular content
|
||||
if delta or not (
|
||||
request.stream_options
|
||||
and request.stream_options.include_usage
|
||||
):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(content=delta if delta else None),
|
||||
finish_reason=(
|
||||
None
|
||||
if request.stream_options
|
||||
and request.stream_options.include_usage
|
||||
else finish_reason_type
|
||||
),
|
||||
matched_stop=(
|
||||
finish_reason["matched"]
|
||||
if finish_reason and "matched" in finish_reason
|
||||
else None
|
||||
),
|
||||
logprobs=choice_logprobs,
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
created=int(time.time()),
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
stream_buffers[index] = new_stream_buffer
|
||||
is_firsts[index] = is_first
|
||||
n_prev_tokens[index] = n_prev_token
|
||||
|
||||
# Final chunk with usage
|
||||
if request.stream_options and request.stream_options.include_usage:
|
||||
usage = self._calculate_streaming_usage_base(
|
||||
prompt_tokens, completion_tokens, cached_tokens, request.n
|
||||
)
|
||||
else:
|
||||
usage = None
|
||||
|
||||
final_chunk = ChatCompletionStreamResponse(
|
||||
# Additional usage chunk
|
||||
if request.stream_options and request.stream_options.include_usage:
|
||||
usage = self._calculate_streaming_usage_base(
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
cached_tokens,
|
||||
request.n,
|
||||
)
|
||||
usage_chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
created=int(time.time()),
|
||||
choices=[
|
||||
ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(),
|
||||
finish_reason=finish_reason_type,
|
||||
)
|
||||
],
|
||||
choices=[], # Empty choices array as per OpenAI spec
|
||||
model=request.model,
|
||||
usage=usage,
|
||||
)
|
||||
yield f"data: {final_chunk.model_dump_json()}\n\n"
|
||||
yield f"data: {usage_chunk.model_dump_json()}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
error = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {error}\n\n"
|
||||
except Exception as e:
|
||||
error = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {error}\n\n"
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate_stream_resp(),
|
||||
media_type="text/event-stream",
|
||||
background=self.tokenizer_manager.create_abort_task(adapted_request),
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def _handle_non_streaming_request(
|
||||
self,
|
||||
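
When stream_options.include_usage is set, the last data event before [DONE] carries usage with an empty choices array, as the inline comment above notes. A hedged sketch of the resulting wire format (field names follow the OpenAI streaming schema; the values are illustrative):

import json
import time

usage_event = {
    "id": "chatcmpl-example",
    "object": "chat.completion.chunk",
    "created": int(time.time()),
    "model": "example-model",
    "choices": [],  # empty choices array on the usage chunk, per OpenAI spec
    "usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46},
}
print(f"data: {json.dumps(usage_event)}\n")
print("data: [DONE]\n")
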
@@ -658,9 +588,6 @@ class OpenAIServingChat(OpenAIServingBase):
    request,
    ret,
    int(time.time()),
-    cache_report=self.tokenizer_manager.server_args.enable_cache_report,
-    tool_call_parser=self.tokenizer_manager.server_args.tool_call_parser,
-    reasoning_parser=self.tokenizer_manager.server_args.reasoning_parser,
)

return response
@@ -670,9 +597,6 @@ class OpenAIServingChat(OpenAIServingBase):
    request: ChatCompletionRequest,
    ret: List[Dict[str, Any]],
    created: int,
-    cache_report: bool = False,
-    tool_call_parser: Optional[str] = None,
-    reasoning_parser: Optional[str] = None,
) -> ChatCompletionResponse:
    """Build chat completion response from generation results"""
    choices = []
@@ -691,6 +615,7 @@ class OpenAIServingChat(OpenAIServingBase):
enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
    "enable_thinking", True
)
+reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
if reasoning_parser and request.separate_reasoning and enable_thinking:
    try:
        parser = ReasoningParser(
@@ -708,6 +633,7 @@ class OpenAIServingChat(OpenAIServingBase):
# Handle tool calls
tool_calls = None
if request.tool_choice != "none" and request.tools:
+    tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
    tool_calls, text, finish_reason = self._process_tool_calls(
        text, request.tools, tool_call_parser, finish_reason
    )
@@ -731,6 +657,7 @@ class OpenAIServingChat(OpenAIServingBase):
choices.append(choice_data)

# Calculate usage
+cache_report = self.tokenizer_manager.server_args.enable_cache_report
usage = aggregate_token_usage(ret, request.n, cache_report)

return ChatCompletionResponse(
@@ -810,7 +737,7 @@ class OpenAIServingChat(OpenAIServingBase):
text, call_info_list = parser.parse_non_stream(text)
tool_calls = [
    ToolCall(
-        id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
+        id=f"call_{uuid.uuid4().hex[:24]}",
        function=FunctionResponse(
            name=call_info.name, arguments=call_info.parameters
        ),
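
The non-streaming parser now derives tool-call IDs from uuid4().hex[:24] instead of a base64-encoded UUID; both produce an opaque call_-prefixed identifier. For example:

import base64
import uuid

old_style = f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}"
new_style = f"call_{uuid.uuid4().hex[:24]}"
print(old_style)  # 'call_' + 22 url-safe base64 characters
print(new_style)  # 'call_' + the first 24 hex characters of a UUID4
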
@@ -894,6 +821,16 @@ class OpenAIServingChat(OpenAIServingBase):

# Yield tool calls
for call_item in calls:
+    # Tool call ID should be generated only once per tool call
+    if call_item.name:
+        # First chunk: include ID and function name
+        tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
+        function_name = call_item.name
+    else:
+        # Subsequent chunks: null ID and name for argument deltas
+        tool_call_id = None
+        function_name = None
+
    if finish_reason_type == "stop":
        # Handle remaining arguments
        latest_delta_len = 0
@@ -912,10 +849,10 @@ class OpenAIServingChat(OpenAIServingBase):
finish_reason_type = "tool_calls"

tool_call = ToolCall(
-    id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
+    id=tool_call_id,
    index=call_item.tool_index,
    function=FunctionResponse(
-        name=call_item.name,
+        name=function_name,
        arguments=call_item.parameters,
    ),
)
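
In the streaming path the tool-call ID and function name are therefore sent only on the first chunk of a call; later chunks keep them as None and only append argument text. A hedged sketch of the resulting delta sequence (plain dicts instead of the protocol models):

import uuid
from typing import Optional

def tool_call_delta(name: Optional[str], arguments: str, tool_index: int) -> dict:
    # First chunk of a call carries a fresh ID and the function name;
    # follow-up chunks leave both as None and only extend the arguments.
    tool_call_id = f"call_{uuid.uuid4().hex[:24]}" if name else None
    return {
        "id": tool_call_id,
        "index": tool_index,
        "function": {"name": name, "arguments": arguments},
    }

print(tool_call_delta("get_weather", '{"city": "Par', tool_index=0))
print(tool_call_delta(None, 'is"}', tool_index=0))
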
@@ -1,5 +1,6 @@
import logging
import time
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, Dict, List, Union

from fastapi import Request
from fastapi.responses import StreamingResponse
@@ -23,6 +24,8 @@ from sglang.srt.entrypoints.openai.utils import (
)
from sglang.srt.managers.io_struct import GenerateReqInput

+logger = logging.getLogger(__name__)
+

class OpenAIServingCompletion(OpenAIServingBase):
    """Handler for completion requests"""
@@ -30,134 +33,54 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
def _request_id_prefix(self) -> str:
|
||||
return "cmpl-"
|
||||
|
||||
def _validate_request(self, request: CompletionRequest) -> Optional[str]:
|
||||
"""Validate completion prompt format and content"""
|
||||
if not (prompt := request.prompt):
|
||||
return "Prompt cannot be None"
|
||||
|
||||
if isinstance(prompt, str):
|
||||
if not prompt.strip():
|
||||
return "Prompt cannot be empty or whitespace only"
|
||||
elif isinstance(prompt, list):
|
||||
if not prompt:
|
||||
return "Prompt list cannot be empty"
|
||||
|
||||
# Check if it's a list of strings
|
||||
if all(isinstance(item, str) for item in prompt):
|
||||
for i, item in enumerate(prompt):
|
||||
if not item.strip():
|
||||
return f"Prompt at index {i} cannot be empty or whitespace only"
|
||||
|
||||
# Check if it's a list of token IDs (integers)
|
||||
elif all(isinstance(item, int) for item in prompt):
|
||||
if any(item < 0 for item in prompt):
|
||||
return "Token IDs must be non-negative"
|
||||
|
||||
# Check if it's a list of lists (multiple token sequences)
|
||||
elif all(isinstance(item, list) for item in prompt):
|
||||
for i, item in enumerate(prompt):
|
||||
if not item:
|
||||
return f"Token sequence at index {i} cannot be empty"
|
||||
if not all(isinstance(token, int) for token in item):
|
||||
return f"Token sequence at index {i} must contain only integers"
|
||||
if any(token < 0 for token in item):
|
||||
return (
|
||||
f"Token sequence at index {i} contains negative token IDs"
|
||||
)
|
||||
else:
|
||||
return "Prompt must be string, list of strings, list of integers, or list of integer lists"
|
||||
else:
|
||||
return "Prompt must be string or list"
|
||||
|
||||
return None
|
||||
|
||||
def _convert_to_internal_request(
|
||||
self,
|
||||
all_requests: List[CompletionRequest],
|
||||
request_ids: List[str],
|
||||
) -> tuple[GenerateReqInput, Union[CompletionRequest, List[CompletionRequest]]]:
|
||||
request: CompletionRequest,
|
||||
) -> tuple[GenerateReqInput, CompletionRequest]:
|
||||
"""Convert OpenAI completion request to internal format"""
|
||||
# Validate batch requests
|
||||
if len(all_requests) > 1:
|
||||
first_prompt_type = type(all_requests[0].prompt)
|
||||
for request in all_requests:
|
||||
assert (
|
||||
type(request.prompt) is first_prompt_type
|
||||
), "All prompts must be of the same type in file input settings"
|
||||
if request.n > 1:
|
||||
raise ValueError(
|
||||
"Parallel sampling is not supported for completions from files"
|
||||
)
|
||||
|
||||
prompts = []
|
||||
sampling_params_list = []
|
||||
return_logprobs = []
|
||||
logprob_start_lens = []
|
||||
top_logprobs_nums = []
|
||||
lora_paths = []
|
||||
|
||||
for request in all_requests:
|
||||
# Process prompt
|
||||
prompt = request.prompt
|
||||
if is_completion_template_defined():
|
||||
prompt = generate_completion_prompt_from_request(request)
|
||||
|
||||
prompts.append(prompt)
|
||||
|
||||
lora_paths.append(request.lora_path)
|
||||
|
||||
# Set logprob start length based on echo and logprobs
|
||||
if request.echo and request.logprobs:
|
||||
current_logprob_start_len = 0
|
||||
else:
|
||||
current_logprob_start_len = -1
|
||||
|
||||
# Build sampling parameters
|
||||
sampling_params = self._build_sampling_params(request)
|
||||
sampling_params_list.append(sampling_params)
|
||||
|
||||
return_logprobs.append(request.logprobs is not None)
|
||||
logprob_start_lens.append(current_logprob_start_len)
|
||||
top_logprobs_nums.append(
|
||||
request.logprobs if request.logprobs is not None else 0
|
||||
# NOTE: with openai API, the prompt's logprobs are always not computed
|
||||
if request.echo and request.logprobs:
|
||||
logger.warning(
|
||||
"Echo is not compatible with logprobs. "
|
||||
"To compute logprobs of input prompt, please use the native /generate API."
|
||||
)
|
||||
# Process prompt
|
||||
prompt = request.prompt
|
||||
if is_completion_template_defined():
|
||||
prompt = generate_completion_prompt_from_request(request)
|
||||
|
||||
# Handle single vs multiple requests
|
||||
if len(all_requests) == 1:
|
||||
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
|
||||
prompt_kwargs = {"text": prompts[0]}
|
||||
else:
|
||||
prompt_kwargs = {"input_ids": prompts[0]}
|
||||
sampling_params_list = sampling_params_list[0]
|
||||
return_logprobs = return_logprobs[0]
|
||||
logprob_start_lens = logprob_start_lens[0]
|
||||
top_logprobs_nums = top_logprobs_nums[0]
|
||||
lora_paths = lora_paths[0]
|
||||
request_ids = request_ids[0]
|
||||
# Set logprob start length based on echo and logprobs
|
||||
if request.echo and request.logprobs:
|
||||
logprob_start_len = 0
|
||||
else:
|
||||
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
|
||||
prompt_kwargs = {"text": prompts}
|
||||
else:
|
||||
prompt_kwargs = {"input_ids": prompts}
|
||||
logprob_start_len = -1
|
||||
|
||||
# Build sampling parameters
|
||||
sampling_params = self._build_sampling_params(request)
|
||||
|
||||
# Determine prompt format
|
||||
if isinstance(prompt, str) or (
|
||||
isinstance(prompt, list) and isinstance(prompt[0], str)
|
||||
):
|
||||
prompt_kwargs = {"text": prompt}
|
||||
else:
|
||||
prompt_kwargs = {"input_ids": prompt}
|
||||
|
||||
adapted_request = GenerateReqInput(
|
||||
**prompt_kwargs,
|
||||
sampling_params=sampling_params_list,
|
||||
return_logprob=return_logprobs,
|
||||
top_logprobs_num=top_logprobs_nums,
|
||||
logprob_start_len=logprob_start_lens,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=request.logprobs is not None,
|
||||
top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
|
||||
logprob_start_len=logprob_start_len,
|
||||
return_text_in_logprobs=True,
|
||||
stream=all_requests[0].stream,
|
||||
rid=request_ids,
|
||||
lora_path=lora_paths,
|
||||
bootstrap_host=all_requests[0].bootstrap_host,
|
||||
bootstrap_port=all_requests[0].bootstrap_port,
|
||||
bootstrap_room=all_requests[0].bootstrap_room,
|
||||
stream=request.stream,
|
||||
lora_path=request.lora_path,
|
||||
bootstrap_host=request.bootstrap_host,
|
||||
bootstrap_port=request.bootstrap_port,
|
||||
bootstrap_room=request.bootstrap_room,
|
||||
)
|
||||
|
||||
return adapted_request, (
|
||||
all_requests if len(all_requests) > 1 else all_requests[0]
|
||||
)
|
||||
return adapted_request, request
|
||||
|
||||
def _build_sampling_params(self, request: CompletionRequest) -> Dict[str, Any]:
|
||||
"""Build sampling parameters for the request"""
|
||||
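
With batching gone, the completion handler above converts exactly one request: the prompt may be a string, a list of strings, or pre-tokenized IDs, and echo combined with logprobs now only logs a warning instead of computing prompt logprobs. A simplified sketch of the prompt dispatch (standalone illustration, not the handler itself):

from typing import Any, Dict, List, Union

Prompt = Union[str, List[str], List[int], List[List[int]]]

def prompt_kwargs(prompt: Prompt) -> Dict[str, Any]:
    # Strings (or lists of strings) are passed as text; anything else is
    # treated as pre-tokenized input IDs.
    if isinstance(prompt, str) or (
        isinstance(prompt, list) and prompt and isinstance(prompt[0], str)
    ):
        return {"text": prompt}
    return {"input_ids": prompt}

print(prompt_kwargs("Hello"))               # {'text': 'Hello'}
print(prompt_kwargs(["Hello", "Bonjour"]))  # {'text': ['Hello', 'Bonjour']}
print(prompt_kwargs([1, 2, 3]))             # {'input_ids': [1, 2, 3]}
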
@@ -184,9 +107,6 @@ class OpenAIServingCompletion(OpenAIServingBase):
    "logit_bias": request.logit_bias,
}

-# No additional completion-specific parameters needed currently
-# (json_schema is already handled in base method)
-
return sampling_params

async def _handle_streaming_request(
@@ -196,122 +116,126 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
raw_request: Request,
|
||||
) -> StreamingResponse:
|
||||
"""Handle streaming completion request"""
|
||||
created = int(time.time())
|
||||
|
||||
async def generate_stream_resp():
|
||||
stream_buffers = {}
|
||||
n_prev_tokens = {}
|
||||
prompt_tokens = {}
|
||||
completion_tokens = {}
|
||||
cached_tokens = {}
|
||||
|
||||
try:
|
||||
async for content in self.tokenizer_manager.generate_request(
|
||||
adapted_request, raw_request
|
||||
):
|
||||
index = content.get("index", 0)
|
||||
|
||||
stream_buffer = stream_buffers.get(index, "")
|
||||
n_prev_token = n_prev_tokens.get(index, 0)
|
||||
|
||||
text = content["text"]
|
||||
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
|
||||
completion_tokens[index] = content["meta_info"]["completion_tokens"]
|
||||
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
|
||||
|
||||
# Handle echo for first chunk
|
||||
if not stream_buffer: # The first chunk
|
||||
if request.echo:
|
||||
echo_text = self._get_echo_text(request, index)
|
||||
text = echo_text + text
|
||||
|
||||
# Handle logprobs
|
||||
logprobs = None
|
||||
if request.logprobs is not None:
|
||||
# The first chunk and echo is enabled.
|
||||
if not stream_buffer and request.echo:
|
||||
input_token_logprobs = content["meta_info"][
|
||||
"input_token_logprobs"
|
||||
]
|
||||
input_top_logprobs = content["meta_info"][
|
||||
"input_top_logprobs"
|
||||
]
|
||||
else:
|
||||
input_token_logprobs = None
|
||||
input_top_logprobs = None
|
||||
|
||||
logprobs = to_openai_style_logprobs(
|
||||
input_token_logprobs=input_token_logprobs,
|
||||
input_top_logprobs=input_top_logprobs,
|
||||
output_token_logprobs=content["meta_info"][
|
||||
"output_token_logprobs"
|
||||
][n_prev_token:],
|
||||
output_top_logprobs=content["meta_info"][
|
||||
"output_top_logprobs"
|
||||
][n_prev_token:],
|
||||
)
|
||||
n_prev_token = len(
|
||||
content["meta_info"]["output_token_logprobs"]
|
||||
)
|
||||
|
||||
# Generate delta
|
||||
delta = text[len(stream_buffer) :]
|
||||
stream_buffer = stream_buffer + delta
|
||||
finish_reason = content["meta_info"]["finish_reason"]
|
||||
|
||||
choice_data = CompletionResponseStreamChoice(
|
||||
index=index,
|
||||
text=delta,
|
||||
logprobs=logprobs,
|
||||
finish_reason=finish_reason["type"] if finish_reason else None,
|
||||
matched_stop=(
|
||||
finish_reason["matched"]
|
||||
if finish_reason and "matched" in finish_reason
|
||||
else None
|
||||
),
|
||||
)
|
||||
chunk = CompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
created=created,
|
||||
object="text_completion",
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
|
||||
stream_buffers[index] = stream_buffer
|
||||
n_prev_tokens[index] = n_prev_token
|
||||
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
# Handle final usage chunk
|
||||
if request.stream_options and request.stream_options.include_usage:
|
||||
usage = self._calculate_streaming_usage_base(
|
||||
prompt_tokens, completion_tokens, cached_tokens, request.n
|
||||
)
|
||||
final_usage_chunk = CompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
created=created,
|
||||
choices=[],
|
||||
model=request.model,
|
||||
usage=usage,
|
||||
)
|
||||
final_usage_data = final_usage_chunk.model_dump_json(
|
||||
exclude_none=True
|
||||
)
|
||||
yield f"data: {final_usage_data}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
error = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {error}\n\n"
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate_stream_resp(),
|
||||
self._generate_completion_stream(adapted_request, request, raw_request),
|
||||
media_type="text/event-stream",
|
||||
background=self.tokenizer_manager.create_abort_task(adapted_request),
|
||||
)
|
||||
|
||||
async def _generate_completion_stream(
|
||||
self,
|
||||
adapted_request: GenerateReqInput,
|
||||
request: CompletionRequest,
|
||||
raw_request: Request,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate streaming completion response"""
|
||||
created = int(time.time())
|
||||
|
||||
# State tracking for streaming
|
||||
stream_buffers = {}
|
||||
n_prev_tokens = {}
|
||||
|
||||
# Usage tracking
|
||||
prompt_tokens = {}
|
||||
completion_tokens = {}
|
||||
cached_tokens = {}
|
||||
|
||||
try:
|
||||
async for content in self.tokenizer_manager.generate_request(
|
||||
adapted_request, raw_request
|
||||
):
|
||||
index = content.get("index", 0)
|
||||
|
||||
text = content["text"]
|
||||
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
|
||||
completion_tokens[index] = content["meta_info"]["completion_tokens"]
|
||||
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
|
||||
|
||||
stream_buffer = stream_buffers.get(index, "")
|
||||
# Handle echo for first chunk
|
||||
if not stream_buffer: # The first chunk
|
||||
if request.echo:
|
||||
echo_text = self._get_echo_text(request, index)
|
||||
text = echo_text + text
|
||||
|
||||
# Handle logprobs
|
||||
logprobs = None
|
||||
if request.logprobs is not None:
|
||||
# The first chunk and echo is enabled.
|
||||
if not stream_buffer and request.echo:
|
||||
input_token_logprobs = content["meta_info"][
|
||||
"input_token_logprobs"
|
||||
]
|
||||
input_top_logprobs = content["meta_info"]["input_top_logprobs"]
|
||||
else:
|
||||
input_token_logprobs = None
|
||||
input_top_logprobs = None
|
||||
|
||||
n_prev_token = n_prev_tokens.get(index, 0)
|
||||
logprobs = to_openai_style_logprobs(
|
||||
input_token_logprobs=input_token_logprobs,
|
||||
input_top_logprobs=input_top_logprobs,
|
||||
output_token_logprobs=content["meta_info"][
|
||||
"output_token_logprobs"
|
||||
][n_prev_token:],
|
||||
output_top_logprobs=content["meta_info"]["output_top_logprobs"][
|
||||
n_prev_token:
|
||||
],
|
||||
)
|
||||
n_prev_tokens[index] = len(
|
||||
content["meta_info"]["output_token_logprobs"]
|
||||
)
|
||||
|
||||
# Generate delta
|
||||
delta = text[len(stream_buffer) :]
|
||||
stream_buffers[index] = stream_buffer + delta
|
||||
finish_reason = content["meta_info"]["finish_reason"]
|
||||
|
||||
choice_data = CompletionResponseStreamChoice(
|
||||
index=index,
|
||||
text=delta,
|
||||
logprobs=logprobs,
|
||||
finish_reason=finish_reason["type"] if finish_reason else None,
|
||||
matched_stop=(
|
||||
finish_reason["matched"]
|
||||
if finish_reason and "matched" in finish_reason
|
||||
else None
|
||||
),
|
||||
)
|
||||
chunk = CompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
created=created,
|
||||
object="text_completion",
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
# Handle final usage chunk
|
||||
if request.stream_options and request.stream_options.include_usage:
|
||||
usage = self._calculate_streaming_usage_base(
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
cached_tokens,
|
||||
request.n,
|
||||
)
|
||||
final_usage_chunk = CompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
created=created,
|
||||
choices=[],
|
||||
model=request.model,
|
||||
usage=usage,
|
||||
)
|
||||
final_usage_data = final_usage_chunk.model_dump_json(exclude_none=True)
|
||||
yield f"data: {final_usage_data}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
error = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {error}\n\n"
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def _handle_non_streaming_request(
|
||||
self,
|
||||
adapted_request: GenerateReqInput,
|
||||
@@ -334,7 +258,6 @@ class OpenAIServingCompletion(OpenAIServingBase):
    request,
    ret,
    int(time.time()),
-    cache_report=self.tokenizer_manager.server_args.enable_cache_report,
)

return response
@@ -344,7 +267,6 @@ class OpenAIServingCompletion(OpenAIServingBase):
    request: CompletionRequest,
    ret: List[Dict[str, Any]],
    created: int,
-    cache_report: bool = False,
) -> CompletionResponse:
    """Build completion response from generation results"""
    choices = []
@@ -352,7 +274,7 @@ class OpenAIServingCompletion(OpenAIServingBase):

# Prepare echo prompts if needed
echo_prompts = []
-if (not isinstance(request, list)) and request.echo:
+if request.echo:
    echo_prompts = self._prepare_echo_prompts(request)
    echo = True
@@ -360,21 +282,13 @@ class OpenAIServingCompletion(OpenAIServingBase):
text = ret_item["text"]

# Handle echo
-if isinstance(request, list) and request[idx].echo:
-    echo = True
-    text = request[idx].prompt + text
-elif echo and not isinstance(request, list):
+if echo:
    prompt_index = idx // request.n
    text = echo_prompts[prompt_index] + text

# Handle logprobs
logprobs = None
-if isinstance(request, list) and request[idx].logprobs is not None:
-    logprobs = True
-elif (not isinstance(request, list)) and request.logprobs is not None:
-    logprobs = True
-
-if logprobs:
+if request.logprobs is not None:
    if echo:
        input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
        input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
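
With n samples per prompt, result index idx maps back to its echo prompt via integer division, as in the hunk above. For example:

n = 2  # samples per prompt
echo_prompts = ["prompt A", "prompt B"]
completions = ["a1", "a2", "b1", "b2"]  # results arrive grouped by prompt

for idx, text in enumerate(completions):
    prompt_index = idx // n
    print(echo_prompts[prompt_index] + " -> " + text)
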
@@ -407,6 +321,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
choices.append(choice_data)

# Calculate usage
+cache_report = self.tokenizer_manager.server_args.enable_cache_report
usage = aggregate_token_usage(ret, request.n, cache_report)

return CompletionResponse(
@@ -54,35 +54,25 @@ class OpenAIServingEmbedding(OpenAIServingBase):
|
||||
return f"All items in input list must be integers"
|
||||
if item < 0:
|
||||
return f"Token ID at index {i} must be non-negative"
|
||||
elif isinstance(first_item, list):
|
||||
# List of lists (multiple token sequences)
|
||||
for i, item in enumerate(input):
|
||||
if not isinstance(item, list):
|
||||
return f"Input at index {i} must be a list"
|
||||
if not item:
|
||||
return f"Input at index {i} cannot be empty"
|
||||
if not all(isinstance(token, int) for token in item):
|
||||
return f"Input at index {i} must contain only integers"
|
||||
if any(token < 0 for token in item):
|
||||
return f"Input at index {i} contains negative token IDs"
|
||||
# Note: MultimodalEmbeddingInput validation would be handled by Pydantic
|
||||
|
||||
return None

def _convert_to_internal_request(
    self,
    request: EmbeddingRequest,
-    request_id: str,
-) -> tuple[EmbeddingReqInput, Union[EmbeddingRequest, List[EmbeddingRequest]]]:
+) -> tuple[EmbeddingReqInput, EmbeddingRequest]:
    """Convert OpenAI embedding request to internal format"""
    prompt = request.input

    if isinstance(prompt, str):
        # Single string input
        prompt_kwargs = {"text": prompt}
    elif isinstance(prompt, list):
        if len(prompt) > 0 and isinstance(prompt[0], str):
-            # List of strings
-            prompt_kwargs = {"text": prompt}
+            # List of strings - if it's a single string in a list, treat as single string
+            if len(prompt) == 1:
+                prompt_kwargs = {"text": prompt[0]}
+            else:
+                prompt_kwargs = {"text": prompt}
        elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
            # Handle multimodal embedding inputs
            texts = []
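
Embedding input that is a one-element list of strings is now unwrapped and sent as a single string. A small sketch of that normalization:

from typing import List, Union

def embedding_text_kwargs(prompt: Union[str, List[str]]) -> dict:
    # A single string inside a list is treated the same as a bare string.
    if isinstance(prompt, list) and len(prompt) == 1 and isinstance(prompt[0], str):
        return {"text": prompt[0]}
    return {"text": prompt}

print(embedding_text_kwargs(["hello world"]))     # {'text': 'hello world'}
print(embedding_text_kwargs(["hello", "world"]))  # {'text': ['hello', 'world']}
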
@@ -94,7 +84,6 @@ class OpenAIServingEmbedding(OpenAIServingBase):

generate_prompts = []
# Check if we have a chat template for multimodal embeddings
# This would need to be passed in from the server configuration
chat_template_name = getattr(
    self.tokenizer_manager, "chat_template_name", None
)
@@ -121,6 +110,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
else:
    # Other types (should not happen but handle gracefully)
    prompt_kwargs = {"input_ids": prompt}

adapted_request = EmbeddingReqInput(
    **prompt_kwargs,
)