Refine OpenAI serving entrypoint to remove batch requests (#7372)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com> Co-authored-by: Chang Su <csu272@usc.edu>
This commit is contained in:
@@ -20,7 +20,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from enum import auto
|
from enum import auto
|
||||||
|
|
||||||
from sglang.srt.openai_api.protocol import ChatCompletionRequest
|
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
completion_template_name = None
|
completion_template_name = None
|
||||||
@@ -116,7 +116,7 @@ def is_completion_template_defined() -> bool:
|
|||||||
return completion_template_name is not None
|
return completion_template_name is not None
|
||||||
|
|
||||||
|
|
||||||
def generate_completion_prompt_from_request(request: ChatCompletionRequest) -> str:
|
def generate_completion_prompt_from_request(request: CompletionRequest) -> str:
|
||||||
global completion_template_name
|
global completion_template_name
|
||||||
if request.suffix == "":
|
if request.suffix == "":
|
||||||
return request.prompt
|
return request.prompt
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import ORJSONResponse, StreamingResponse
|
from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||||
@@ -37,7 +37,7 @@ class OpenAIServingBase(ABC):
|
|||||||
|
|
||||||
# Convert to internal format
|
# Convert to internal format
|
||||||
adapted_request, processed_request = self._convert_to_internal_request(
|
adapted_request, processed_request = self._convert_to_internal_request(
|
||||||
request, self._generate_request_id_base(request)
|
request
|
||||||
)
|
)
|
||||||
|
|
||||||
# Note(Xinyuan): raw_request below is only used for detecting the connection of the client
|
# Note(Xinyuan): raw_request below is only used for detecting the connection of the client
|
||||||
@@ -74,10 +74,7 @@ class OpenAIServingBase(ABC):
|
|||||||
def _convert_to_internal_request(
|
def _convert_to_internal_request(
|
||||||
self,
|
self,
|
||||||
request: OpenAIServingRequest,
|
request: OpenAIServingRequest,
|
||||||
request_id: str,
|
) -> tuple[GenerateReqInput, OpenAIServingRequest]:
|
||||||
) -> tuple[
|
|
||||||
GenerateReqInput, Union[OpenAIServingRequest, List[OpenAIServingRequest]]
|
|
||||||
]:
|
|
||||||
"""Convert OpenAI request to internal format"""
|
"""Convert OpenAI request to internal format"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
@@ -52,137 +52,56 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
def _request_id_prefix(self) -> str:
|
def _request_id_prefix(self) -> str:
|
||||||
return "chatcmpl-"
|
return "chatcmpl-"
|
||||||
|
|
||||||
def _validate_request(self, request: ChatCompletionRequest) -> Optional[str]:
|
|
||||||
"""Validate chat messages format and content"""
|
|
||||||
if not (messages := request.messages):
|
|
||||||
return "Messages cannot be empty"
|
|
||||||
|
|
||||||
# Check for alternating user/assistant pattern (optional validation)
|
|
||||||
roles = [msg.role for msg in messages]
|
|
||||||
|
|
||||||
# First message should typically be from user or system
|
|
||||||
if roles[0] not in ["user", "system"]:
|
|
||||||
return "First message should be from 'user' or 'system'"
|
|
||||||
|
|
||||||
# Check for consecutive assistant messages (which might indicate an error)
|
|
||||||
for i in range(1, len(roles)):
|
|
||||||
if roles[i] == "assistant" and roles[i - 1] == "assistant":
|
|
||||||
# This is actually allowed in some cases, so just warn
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Validate message content
|
|
||||||
for i, msg in enumerate(messages):
|
|
||||||
if msg.role == "user":
|
|
||||||
if not msg.content:
|
|
||||||
return f"User message at index {i} has no content"
|
|
||||||
elif msg.role == "assistant":
|
|
||||||
# Assistant messages can have no content if they have tool_calls
|
|
||||||
if not msg.content and not getattr(msg, "tool_calls", None):
|
|
||||||
return (
|
|
||||||
f"Assistant message at index {i} has no content or tool calls"
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _convert_to_internal_request(
|
def _convert_to_internal_request(
|
||||||
self,
|
self,
|
||||||
all_requests: List[ChatCompletionRequest],
|
request: ChatCompletionRequest,
|
||||||
request_ids: List[str],
|
) -> tuple[GenerateReqInput, ChatCompletionRequest]:
|
||||||
) -> tuple[
|
|
||||||
GenerateReqInput, Union[ChatCompletionRequest, List[ChatCompletionRequest]]
|
|
||||||
]:
|
|
||||||
"""Convert OpenAI chat completion request to internal format"""
|
"""Convert OpenAI chat completion request to internal format"""
|
||||||
input_ids = []
|
|
||||||
prompts = []
|
|
||||||
sampling_params_list = []
|
|
||||||
image_data_list = []
|
|
||||||
audio_data_list = []
|
|
||||||
return_logprobs = []
|
|
||||||
logprob_start_lens = []
|
|
||||||
top_logprobs_nums = []
|
|
||||||
modalities_list = []
|
|
||||||
lora_paths = []
|
|
||||||
|
|
||||||
is_multimodal = self.tokenizer_manager.model_config.is_multimodal
|
is_multimodal = self.tokenizer_manager.model_config.is_multimodal
|
||||||
|
|
||||||
for request in all_requests:
|
# Process messages and apply chat template
|
||||||
# Process messages and apply chat template
|
(
|
||||||
(
|
prompt,
|
||||||
prompt,
|
prompt_ids,
|
||||||
prompt_ids,
|
image_data,
|
||||||
image_data,
|
audio_data,
|
||||||
audio_data,
|
modalities,
|
||||||
modalities,
|
stop,
|
||||||
stop,
|
tool_call_constraint,
|
||||||
tool_call_constraint,
|
) = self._process_messages(request, is_multimodal)
|
||||||
) = self._process_messages(request, is_multimodal)
|
|
||||||
|
|
||||||
input_ids.append(prompt_ids)
|
# Build sampling parameters
|
||||||
prompts.append(prompt)
|
sampling_params = self._build_sampling_params(
|
||||||
return_logprobs.append(request.logprobs)
|
request, stop, tool_call_constraint
|
||||||
logprob_start_lens.append(-1)
|
)
|
||||||
top_logprobs_nums.append(request.top_logprobs or 0)
|
|
||||||
lora_paths.append(request.lora_path)
|
|
||||||
|
|
||||||
# Build sampling parameters
|
|
||||||
sampling_params = self._build_sampling_params(
|
|
||||||
request, stop, tool_call_constraint
|
|
||||||
)
|
|
||||||
sampling_params_list.append(sampling_params)
|
|
||||||
|
|
||||||
image_data_list.append(image_data)
|
|
||||||
audio_data_list.append(audio_data)
|
|
||||||
modalities_list.append(modalities)
|
|
||||||
|
|
||||||
# Handle single vs multiple requests
|
# Handle single vs multiple requests
|
||||||
if len(all_requests) == 1:
|
if is_multimodal:
|
||||||
if is_multimodal:
|
prompt_kwargs = {"text": prompt}
|
||||||
prompt_kwargs = {"text": prompts[0]}
|
|
||||||
else:
|
|
||||||
if isinstance(input_ids[0], str):
|
|
||||||
prompt_kwargs = {"text": input_ids[0]}
|
|
||||||
else:
|
|
||||||
prompt_kwargs = {"input_ids": input_ids[0]}
|
|
||||||
|
|
||||||
sampling_params_list = sampling_params_list[0]
|
|
||||||
image_data_list = image_data_list[0]
|
|
||||||
audio_data_list = audio_data_list[0]
|
|
||||||
return_logprobs = return_logprobs[0]
|
|
||||||
logprob_start_lens = logprob_start_lens[0]
|
|
||||||
top_logprobs_nums = top_logprobs_nums[0]
|
|
||||||
modalities_list = modalities_list[0]
|
|
||||||
lora_paths = lora_paths[0]
|
|
||||||
request_ids = request_ids[0]
|
|
||||||
else:
|
else:
|
||||||
if is_multimodal:
|
if isinstance(prompt_ids, str):
|
||||||
prompt_kwargs = {"text": prompts}
|
prompt_kwargs = {"text": prompt_ids}
|
||||||
else:
|
else:
|
||||||
if isinstance(input_ids[0], str):
|
prompt_kwargs = {"input_ids": prompt_ids}
|
||||||
prompt_kwargs = {"text": input_ids}
|
|
||||||
else:
|
|
||||||
prompt_kwargs = {"input_ids": input_ids}
|
|
||||||
|
|
||||||
adapted_request = GenerateReqInput(
|
adapted_request = GenerateReqInput(
|
||||||
**prompt_kwargs,
|
**prompt_kwargs,
|
||||||
image_data=image_data_list,
|
image_data=image_data,
|
||||||
audio_data=audio_data_list,
|
audio_data=audio_data,
|
||||||
sampling_params=sampling_params_list,
|
sampling_params=sampling_params,
|
||||||
return_logprob=return_logprobs,
|
return_logprob=request.logprobs,
|
||||||
logprob_start_len=logprob_start_lens,
|
logprob_start_len=-1,
|
||||||
top_logprobs_num=top_logprobs_nums,
|
top_logprobs_num=request.top_logprobs or 0,
|
||||||
stream=all_requests[0].stream,
|
stream=request.stream,
|
||||||
return_text_in_logprobs=True,
|
return_text_in_logprobs=True,
|
||||||
rid=request_ids,
|
modalities=modalities,
|
||||||
modalities=modalities_list,
|
lora_path=request.lora_path,
|
||||||
lora_path=lora_paths,
|
bootstrap_host=request.bootstrap_host,
|
||||||
bootstrap_host=all_requests[0].bootstrap_host,
|
bootstrap_port=request.bootstrap_port,
|
||||||
bootstrap_port=all_requests[0].bootstrap_port,
|
bootstrap_room=request.bootstrap_room,
|
||||||
bootstrap_room=all_requests[0].bootstrap_room,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return adapted_request, (
|
return adapted_request, request
|
||||||
all_requests if len(all_requests) > 1 else all_requests[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
def _process_messages(
|
def _process_messages(
|
||||||
self, request: ChatCompletionRequest, is_multimodal: bool
|
self, request: ChatCompletionRequest, is_multimodal: bool
|
||||||
@@ -457,55 +376,138 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
raw_request: Request,
|
raw_request: Request,
|
||||||
) -> StreamingResponse:
|
) -> StreamingResponse:
|
||||||
"""Handle streaming chat completion request"""
|
"""Handle streaming chat completion request"""
|
||||||
|
return StreamingResponse(
|
||||||
|
self._generate_chat_stream(adapted_request, request, raw_request),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
background=self.tokenizer_manager.create_abort_task(adapted_request),
|
||||||
|
)
|
||||||
|
|
||||||
async def generate_stream_resp():
|
async def _generate_chat_stream(
|
||||||
parser_dict = {}
|
self,
|
||||||
reasoning_parser_dict = {}
|
adapted_request: GenerateReqInput,
|
||||||
tool_call_first = True
|
request: ChatCompletionRequest,
|
||||||
is_firsts = {}
|
raw_request: Request,
|
||||||
stream_buffers = {}
|
) -> AsyncGenerator[str, None]:
|
||||||
n_prev_tokens = {}
|
"""Generate streaming chat completion response"""
|
||||||
prompt_tokens = {}
|
# Parsers for tool calls and reasoning
|
||||||
completion_tokens = {}
|
parser_dict = {}
|
||||||
cached_tokens = {}
|
reasoning_parser_dict = {}
|
||||||
|
|
||||||
try:
|
# State tracking for streaming
|
||||||
async for content in self.tokenizer_manager.generate_request(
|
is_firsts = {}
|
||||||
adapted_request, raw_request
|
stream_buffers = {}
|
||||||
):
|
n_prev_tokens = {}
|
||||||
index = content.get("index", 0)
|
|
||||||
|
|
||||||
is_first = is_firsts.get(index, True)
|
# Usage tracking
|
||||||
stream_buffer = stream_buffers.get(index, "")
|
prompt_tokens = {}
|
||||||
n_prev_token = n_prev_tokens.get(index, 0)
|
completion_tokens = {}
|
||||||
|
cached_tokens = {}
|
||||||
|
|
||||||
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
|
try:
|
||||||
completion_tokens[index] = content["meta_info"]["completion_tokens"]
|
async for content in self.tokenizer_manager.generate_request(
|
||||||
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
|
adapted_request, raw_request
|
||||||
|
):
|
||||||
|
index = content.get("index", 0)
|
||||||
|
|
||||||
# Handle logprobs
|
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
|
||||||
choice_logprobs = None
|
completion_tokens[index] = content["meta_info"]["completion_tokens"]
|
||||||
if request.logprobs:
|
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
|
||||||
choice_logprobs = self._process_streaming_logprobs(
|
|
||||||
content, n_prev_token
|
|
||||||
)
|
|
||||||
n_prev_token = len(
|
|
||||||
content["meta_info"]["output_token_logprobs"]
|
|
||||||
)
|
|
||||||
|
|
||||||
finish_reason = content["meta_info"]["finish_reason"]
|
# Handle logprobs
|
||||||
finish_reason_type = (
|
choice_logprobs = None
|
||||||
finish_reason["type"] if finish_reason else None
|
if request.logprobs:
|
||||||
|
choice_logprobs = self._process_streaming_logprobs(
|
||||||
|
content, n_prev_tokens.get(index, 0)
|
||||||
|
)
|
||||||
|
n_prev_tokens[index] = len(
|
||||||
|
content["meta_info"]["output_token_logprobs"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# First chunk with role
|
finish_reason = content["meta_info"]["finish_reason"]
|
||||||
if is_first:
|
finish_reason_type = finish_reason["type"] if finish_reason else None
|
||||||
is_first = False
|
|
||||||
delta = DeltaMessage(role="assistant")
|
# First chunk with role
|
||||||
|
if is_firsts.get(index, True):
|
||||||
|
is_firsts[index] = False
|
||||||
|
delta = DeltaMessage(role="assistant", content="")
|
||||||
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
|
index=index,
|
||||||
|
delta=delta,
|
||||||
|
finish_reason=finish_reason_type,
|
||||||
|
matched_stop=(
|
||||||
|
finish_reason["matched"]
|
||||||
|
if finish_reason and "matched" in finish_reason
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
logprobs=choice_logprobs,
|
||||||
|
)
|
||||||
|
chunk = ChatCompletionStreamResponse(
|
||||||
|
id=content["meta_info"]["id"],
|
||||||
|
created=int(time.time()),
|
||||||
|
choices=[choice_data],
|
||||||
|
model=request.model,
|
||||||
|
)
|
||||||
|
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||||
|
|
||||||
|
# Process content delta
|
||||||
|
stream_buffer = stream_buffers.get(index, "")
|
||||||
|
delta = content["text"][len(stream_buffer) :]
|
||||||
|
stream_buffers[index] = stream_buffer + delta
|
||||||
|
|
||||||
|
# Handle reasoning content
|
||||||
|
enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
|
||||||
|
"enable_thinking", True
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
self.tokenizer_manager.server_args.reasoning_parser
|
||||||
|
and request.separate_reasoning
|
||||||
|
and enable_thinking
|
||||||
|
):
|
||||||
|
reasoning_text, delta = self._process_reasoning_stream(
|
||||||
|
index, delta, reasoning_parser_dict, content, request
|
||||||
|
)
|
||||||
|
if reasoning_text:
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
index=index,
|
index=index,
|
||||||
delta=delta,
|
delta=DeltaMessage(reasoning_content=reasoning_text),
|
||||||
finish_reason=finish_reason_type,
|
finish_reason=finish_reason_type,
|
||||||
|
)
|
||||||
|
chunk = ChatCompletionStreamResponse(
|
||||||
|
id=content["meta_info"]["id"],
|
||||||
|
created=int(time.time()),
|
||||||
|
choices=[choice_data],
|
||||||
|
model=request.model,
|
||||||
|
)
|
||||||
|
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||||
|
|
||||||
|
if not delta:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle tool calls
|
||||||
|
if request.tool_choice != "none" and request.tools:
|
||||||
|
async for chunk in self._process_tool_call_stream(
|
||||||
|
index,
|
||||||
|
delta,
|
||||||
|
parser_dict,
|
||||||
|
content,
|
||||||
|
request,
|
||||||
|
finish_reason_type,
|
||||||
|
):
|
||||||
|
yield chunk
|
||||||
|
else:
|
||||||
|
# Regular content
|
||||||
|
if delta or not (
|
||||||
|
request.stream_options and request.stream_options.include_usage
|
||||||
|
):
|
||||||
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
|
index=index,
|
||||||
|
delta=DeltaMessage(content=delta if delta else None),
|
||||||
|
finish_reason=(
|
||||||
|
None
|
||||||
|
if request.stream_options
|
||||||
|
and request.stream_options.include_usage
|
||||||
|
else finish_reason_type
|
||||||
|
),
|
||||||
matched_stop=(
|
matched_stop=(
|
||||||
finish_reason["matched"]
|
finish_reason["matched"]
|
||||||
if finish_reason and "matched" in finish_reason
|
if finish_reason and "matched" in finish_reason
|
||||||
@@ -521,121 +523,49 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
)
|
)
|
||||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||||
|
|
||||||
# Process content delta
|
# Final chunk with finish_reason
|
||||||
delta = content["text"][len(stream_buffer) :]
|
finish_reason_chunk = ChatCompletionStreamResponse(
|
||||||
new_stream_buffer = stream_buffer + delta
|
id=content["meta_info"]["id"],
|
||||||
|
created=int(time.time()),
|
||||||
# Handle reasoning content
|
choices=[
|
||||||
enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
|
ChatCompletionResponseStreamChoice(
|
||||||
"enable_thinking", True
|
index=index,
|
||||||
|
delta=DeltaMessage(),
|
||||||
|
finish_reason=finish_reason_type,
|
||||||
|
matched_stop=(
|
||||||
|
finish_reason["matched"]
|
||||||
|
if finish_reason and "matched" in finish_reason
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if (
|
],
|
||||||
self.tokenizer_manager.server_args.reasoning_parser
|
model=request.model,
|
||||||
and request.separate_reasoning
|
usage=None,
|
||||||
and enable_thinking
|
)
|
||||||
):
|
yield f"data: {finish_reason_chunk.model_dump_json()}\n\n"
|
||||||
reasoning_text, delta = self._process_reasoning_stream(
|
|
||||||
index, delta, reasoning_parser_dict, content, request
|
|
||||||
)
|
|
||||||
if reasoning_text:
|
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
|
||||||
index=index,
|
|
||||||
delta=DeltaMessage(reasoning_content=reasoning_text),
|
|
||||||
finish_reason=finish_reason_type,
|
|
||||||
)
|
|
||||||
chunk = ChatCompletionStreamResponse(
|
|
||||||
id=content["meta_info"]["id"],
|
|
||||||
created=int(time.time()),
|
|
||||||
choices=[choice_data],
|
|
||||||
model=request.model,
|
|
||||||
)
|
|
||||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
|
||||||
|
|
||||||
if not delta:
|
# Additional usage chunk
|
||||||
stream_buffers[index] = new_stream_buffer
|
if request.stream_options and request.stream_options.include_usage:
|
||||||
is_firsts[index] = is_first
|
usage = self._calculate_streaming_usage_base(
|
||||||
n_prev_tokens[index] = n_prev_token
|
prompt_tokens,
|
||||||
continue
|
completion_tokens,
|
||||||
|
cached_tokens,
|
||||||
# Handle tool calls
|
request.n,
|
||||||
if request.tool_choice != "none" and request.tools:
|
)
|
||||||
async for chunk in self._process_tool_call_stream(
|
usage_chunk = ChatCompletionStreamResponse(
|
||||||
index,
|
|
||||||
delta,
|
|
||||||
parser_dict,
|
|
||||||
content,
|
|
||||||
request,
|
|
||||||
finish_reason_type,
|
|
||||||
):
|
|
||||||
yield chunk
|
|
||||||
else:
|
|
||||||
# Regular content
|
|
||||||
if delta or not (
|
|
||||||
request.stream_options
|
|
||||||
and request.stream_options.include_usage
|
|
||||||
):
|
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
|
||||||
index=index,
|
|
||||||
delta=DeltaMessage(content=delta if delta else None),
|
|
||||||
finish_reason=(
|
|
||||||
None
|
|
||||||
if request.stream_options
|
|
||||||
and request.stream_options.include_usage
|
|
||||||
else finish_reason_type
|
|
||||||
),
|
|
||||||
matched_stop=(
|
|
||||||
finish_reason["matched"]
|
|
||||||
if finish_reason and "matched" in finish_reason
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
logprobs=choice_logprobs,
|
|
||||||
)
|
|
||||||
chunk = ChatCompletionStreamResponse(
|
|
||||||
id=content["meta_info"]["id"],
|
|
||||||
created=int(time.time()),
|
|
||||||
choices=[choice_data],
|
|
||||||
model=request.model,
|
|
||||||
)
|
|
||||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
|
||||||
|
|
||||||
stream_buffers[index] = new_stream_buffer
|
|
||||||
is_firsts[index] = is_first
|
|
||||||
n_prev_tokens[index] = n_prev_token
|
|
||||||
|
|
||||||
# Final chunk with usage
|
|
||||||
if request.stream_options and request.stream_options.include_usage:
|
|
||||||
usage = self._calculate_streaming_usage_base(
|
|
||||||
prompt_tokens, completion_tokens, cached_tokens, request.n
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
usage = None
|
|
||||||
|
|
||||||
final_chunk = ChatCompletionStreamResponse(
|
|
||||||
id=content["meta_info"]["id"],
|
id=content["meta_info"]["id"],
|
||||||
created=int(time.time()),
|
created=int(time.time()),
|
||||||
choices=[
|
choices=[], # Empty choices array as per OpenAI spec
|
||||||
ChatCompletionResponseStreamChoice(
|
|
||||||
index=index,
|
|
||||||
delta=DeltaMessage(),
|
|
||||||
finish_reason=finish_reason_type,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
model=request.model,
|
model=request.model,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
)
|
)
|
||||||
yield f"data: {final_chunk.model_dump_json()}\n\n"
|
yield f"data: {usage_chunk.model_dump_json()}\n\n"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error = self.create_streaming_error_response(str(e))
|
error = self.create_streaming_error_response(str(e))
|
||||||
yield f"data: {error}\n\n"
|
yield f"data: {error}\n\n"
|
||||||
|
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
generate_stream_resp(),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
background=self.tokenizer_manager.create_abort_task(adapted_request),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _handle_non_streaming_request(
|
async def _handle_non_streaming_request(
|
||||||
self,
|
self,
|
||||||
@@ -658,9 +588,6 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
request,
|
request,
|
||||||
ret,
|
ret,
|
||||||
int(time.time()),
|
int(time.time()),
|
||||||
cache_report=self.tokenizer_manager.server_args.enable_cache_report,
|
|
||||||
tool_call_parser=self.tokenizer_manager.server_args.tool_call_parser,
|
|
||||||
reasoning_parser=self.tokenizer_manager.server_args.reasoning_parser,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
@@ -670,9 +597,6 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
ret: List[Dict[str, Any]],
|
ret: List[Dict[str, Any]],
|
||||||
created: int,
|
created: int,
|
||||||
cache_report: bool = False,
|
|
||||||
tool_call_parser: Optional[str] = None,
|
|
||||||
reasoning_parser: Optional[str] = None,
|
|
||||||
) -> ChatCompletionResponse:
|
) -> ChatCompletionResponse:
|
||||||
"""Build chat completion response from generation results"""
|
"""Build chat completion response from generation results"""
|
||||||
choices = []
|
choices = []
|
||||||
@@ -691,6 +615,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
|
enable_thinking = getattr(request, "chat_template_kwargs", {}).get(
|
||||||
"enable_thinking", True
|
"enable_thinking", True
|
||||||
)
|
)
|
||||||
|
reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
|
||||||
if reasoning_parser and request.separate_reasoning and enable_thinking:
|
if reasoning_parser and request.separate_reasoning and enable_thinking:
|
||||||
try:
|
try:
|
||||||
parser = ReasoningParser(
|
parser = ReasoningParser(
|
||||||
@@ -708,6 +633,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
# Handle tool calls
|
# Handle tool calls
|
||||||
tool_calls = None
|
tool_calls = None
|
||||||
if request.tool_choice != "none" and request.tools:
|
if request.tool_choice != "none" and request.tools:
|
||||||
|
tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
|
||||||
tool_calls, text, finish_reason = self._process_tool_calls(
|
tool_calls, text, finish_reason = self._process_tool_calls(
|
||||||
text, request.tools, tool_call_parser, finish_reason
|
text, request.tools, tool_call_parser, finish_reason
|
||||||
)
|
)
|
||||||
@@ -731,6 +657,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
choices.append(choice_data)
|
choices.append(choice_data)
|
||||||
|
|
||||||
# Calculate usage
|
# Calculate usage
|
||||||
|
cache_report = self.tokenizer_manager.server_args.enable_cache_report
|
||||||
usage = aggregate_token_usage(ret, request.n, cache_report)
|
usage = aggregate_token_usage(ret, request.n, cache_report)
|
||||||
|
|
||||||
return ChatCompletionResponse(
|
return ChatCompletionResponse(
|
||||||
@@ -810,7 +737,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
text, call_info_list = parser.parse_non_stream(text)
|
text, call_info_list = parser.parse_non_stream(text)
|
||||||
tool_calls = [
|
tool_calls = [
|
||||||
ToolCall(
|
ToolCall(
|
||||||
id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
|
id=f"call_{uuid.uuid4().hex[:24]}",
|
||||||
function=FunctionResponse(
|
function=FunctionResponse(
|
||||||
name=call_info.name, arguments=call_info.parameters
|
name=call_info.name, arguments=call_info.parameters
|
||||||
),
|
),
|
||||||
@@ -894,6 +821,16 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
|
|
||||||
# Yield tool calls
|
# Yield tool calls
|
||||||
for call_item in calls:
|
for call_item in calls:
|
||||||
|
# Tool call ID should be generated only once per tool call
|
||||||
|
if call_item.name:
|
||||||
|
# First chunk: include ID and function name
|
||||||
|
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||||
|
function_name = call_item.name
|
||||||
|
else:
|
||||||
|
# Subsequent chunks: null ID and name for argument deltas
|
||||||
|
tool_call_id = None
|
||||||
|
function_name = None
|
||||||
|
|
||||||
if finish_reason_type == "stop":
|
if finish_reason_type == "stop":
|
||||||
# Handle remaining arguments
|
# Handle remaining arguments
|
||||||
latest_delta_len = 0
|
latest_delta_len = 0
|
||||||
@@ -912,10 +849,10 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
finish_reason_type = "tool_calls"
|
finish_reason_type = "tool_calls"
|
||||||
|
|
||||||
tool_call = ToolCall(
|
tool_call = ToolCall(
|
||||||
id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
|
id=tool_call_id,
|
||||||
index=call_item.tool_index,
|
index=call_item.tool_index,
|
||||||
function=FunctionResponse(
|
function=FunctionResponse(
|
||||||
name=call_item.name,
|
name=function_name,
|
||||||
arguments=call_item.parameters,
|
arguments=call_item.parameters,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, AsyncGenerator, Dict, List, Union
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
@@ -23,6 +24,8 @@ from sglang.srt.entrypoints.openai.utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIServingCompletion(OpenAIServingBase):
|
class OpenAIServingCompletion(OpenAIServingBase):
|
||||||
"""Handler for completion requests"""
|
"""Handler for completion requests"""
|
||||||
@@ -30,134 +33,54 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
def _request_id_prefix(self) -> str:
|
def _request_id_prefix(self) -> str:
|
||||||
return "cmpl-"
|
return "cmpl-"
|
||||||
|
|
||||||
def _validate_request(self, request: CompletionRequest) -> Optional[str]:
|
|
||||||
"""Validate completion prompt format and content"""
|
|
||||||
if not (prompt := request.prompt):
|
|
||||||
return "Prompt cannot be None"
|
|
||||||
|
|
||||||
if isinstance(prompt, str):
|
|
||||||
if not prompt.strip():
|
|
||||||
return "Prompt cannot be empty or whitespace only"
|
|
||||||
elif isinstance(prompt, list):
|
|
||||||
if not prompt:
|
|
||||||
return "Prompt list cannot be empty"
|
|
||||||
|
|
||||||
# Check if it's a list of strings
|
|
||||||
if all(isinstance(item, str) for item in prompt):
|
|
||||||
for i, item in enumerate(prompt):
|
|
||||||
if not item.strip():
|
|
||||||
return f"Prompt at index {i} cannot be empty or whitespace only"
|
|
||||||
|
|
||||||
# Check if it's a list of token IDs (integers)
|
|
||||||
elif all(isinstance(item, int) for item in prompt):
|
|
||||||
if any(item < 0 for item in prompt):
|
|
||||||
return "Token IDs must be non-negative"
|
|
||||||
|
|
||||||
# Check if it's a list of lists (multiple token sequences)
|
|
||||||
elif all(isinstance(item, list) for item in prompt):
|
|
||||||
for i, item in enumerate(prompt):
|
|
||||||
if not item:
|
|
||||||
return f"Token sequence at index {i} cannot be empty"
|
|
||||||
if not all(isinstance(token, int) for token in item):
|
|
||||||
return f"Token sequence at index {i} must contain only integers"
|
|
||||||
if any(token < 0 for token in item):
|
|
||||||
return (
|
|
||||||
f"Token sequence at index {i} contains negative token IDs"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return "Prompt must be string, list of strings, list of integers, or list of integer lists"
|
|
||||||
else:
|
|
||||||
return "Prompt must be string or list"
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _convert_to_internal_request(
|
def _convert_to_internal_request(
|
||||||
self,
|
self,
|
||||||
all_requests: List[CompletionRequest],
|
request: CompletionRequest,
|
||||||
request_ids: List[str],
|
) -> tuple[GenerateReqInput, CompletionRequest]:
|
||||||
) -> tuple[GenerateReqInput, Union[CompletionRequest, List[CompletionRequest]]]:
|
|
||||||
"""Convert OpenAI completion request to internal format"""
|
"""Convert OpenAI completion request to internal format"""
|
||||||
# Validate batch requests
|
# NOTE: with openai API, the prompt's logprobs are always not computed
|
||||||
if len(all_requests) > 1:
|
if request.echo and request.logprobs:
|
||||||
first_prompt_type = type(all_requests[0].prompt)
|
logger.warning(
|
||||||
for request in all_requests:
|
"Echo is not compatible with logprobs. "
|
||||||
assert (
|
"To compute logprobs of input prompt, please use the native /generate API."
|
||||||
type(request.prompt) is first_prompt_type
|
|
||||||
), "All prompts must be of the same type in file input settings"
|
|
||||||
if request.n > 1:
|
|
||||||
raise ValueError(
|
|
||||||
"Parallel sampling is not supported for completions from files"
|
|
||||||
)
|
|
||||||
|
|
||||||
prompts = []
|
|
||||||
sampling_params_list = []
|
|
||||||
return_logprobs = []
|
|
||||||
logprob_start_lens = []
|
|
||||||
top_logprobs_nums = []
|
|
||||||
lora_paths = []
|
|
||||||
|
|
||||||
for request in all_requests:
|
|
||||||
# Process prompt
|
|
||||||
prompt = request.prompt
|
|
||||||
if is_completion_template_defined():
|
|
||||||
prompt = generate_completion_prompt_from_request(request)
|
|
||||||
|
|
||||||
prompts.append(prompt)
|
|
||||||
|
|
||||||
lora_paths.append(request.lora_path)
|
|
||||||
|
|
||||||
# Set logprob start length based on echo and logprobs
|
|
||||||
if request.echo and request.logprobs:
|
|
||||||
current_logprob_start_len = 0
|
|
||||||
else:
|
|
||||||
current_logprob_start_len = -1
|
|
||||||
|
|
||||||
# Build sampling parameters
|
|
||||||
sampling_params = self._build_sampling_params(request)
|
|
||||||
sampling_params_list.append(sampling_params)
|
|
||||||
|
|
||||||
return_logprobs.append(request.logprobs is not None)
|
|
||||||
logprob_start_lens.append(current_logprob_start_len)
|
|
||||||
top_logprobs_nums.append(
|
|
||||||
request.logprobs if request.logprobs is not None else 0
|
|
||||||
)
|
)
|
||||||
|
# Process prompt
|
||||||
|
prompt = request.prompt
|
||||||
|
if is_completion_template_defined():
|
||||||
|
prompt = generate_completion_prompt_from_request(request)
|
||||||
|
|
||||||
# Handle single vs multiple requests
|
# Set logprob start length based on echo and logprobs
|
||||||
if len(all_requests) == 1:
|
if request.echo and request.logprobs:
|
||||||
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
|
logprob_start_len = 0
|
||||||
prompt_kwargs = {"text": prompts[0]}
|
|
||||||
else:
|
|
||||||
prompt_kwargs = {"input_ids": prompts[0]}
|
|
||||||
sampling_params_list = sampling_params_list[0]
|
|
||||||
return_logprobs = return_logprobs[0]
|
|
||||||
logprob_start_lens = logprob_start_lens[0]
|
|
||||||
top_logprobs_nums = top_logprobs_nums[0]
|
|
||||||
lora_paths = lora_paths[0]
|
|
||||||
request_ids = request_ids[0]
|
|
||||||
else:
|
else:
|
||||||
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
|
logprob_start_len = -1
|
||||||
prompt_kwargs = {"text": prompts}
|
|
||||||
else:
|
# Build sampling parameters
|
||||||
prompt_kwargs = {"input_ids": prompts}
|
sampling_params = self._build_sampling_params(request)
|
||||||
|
|
||||||
|
# Determine prompt format
|
||||||
|
if isinstance(prompt, str) or (
|
||||||
|
isinstance(prompt, list) and isinstance(prompt[0], str)
|
||||||
|
):
|
||||||
|
prompt_kwargs = {"text": prompt}
|
||||||
|
else:
|
||||||
|
prompt_kwargs = {"input_ids": prompt}
|
||||||
|
|
||||||
adapted_request = GenerateReqInput(
|
adapted_request = GenerateReqInput(
|
||||||
**prompt_kwargs,
|
**prompt_kwargs,
|
||||||
sampling_params=sampling_params_list,
|
sampling_params=sampling_params,
|
||||||
return_logprob=return_logprobs,
|
return_logprob=request.logprobs is not None,
|
||||||
top_logprobs_num=top_logprobs_nums,
|
top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
|
||||||
logprob_start_len=logprob_start_lens,
|
logprob_start_len=logprob_start_len,
|
||||||
return_text_in_logprobs=True,
|
return_text_in_logprobs=True,
|
||||||
stream=all_requests[0].stream,
|
stream=request.stream,
|
||||||
rid=request_ids,
|
lora_path=request.lora_path,
|
||||||
lora_path=lora_paths,
|
bootstrap_host=request.bootstrap_host,
|
||||||
bootstrap_host=all_requests[0].bootstrap_host,
|
bootstrap_port=request.bootstrap_port,
|
||||||
bootstrap_port=all_requests[0].bootstrap_port,
|
bootstrap_room=request.bootstrap_room,
|
||||||
bootstrap_room=all_requests[0].bootstrap_room,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return adapted_request, (
|
return adapted_request, request
|
||||||
all_requests if len(all_requests) > 1 else all_requests[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
def _build_sampling_params(self, request: CompletionRequest) -> Dict[str, Any]:
|
def _build_sampling_params(self, request: CompletionRequest) -> Dict[str, Any]:
|
||||||
"""Build sampling parameters for the request"""
|
"""Build sampling parameters for the request"""
|
||||||
@@ -184,9 +107,6 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
"logit_bias": request.logit_bias,
|
"logit_bias": request.logit_bias,
|
||||||
}
|
}
|
||||||
|
|
||||||
# No additional completion-specific parameters needed currently
|
|
||||||
# (json_schema is already handled in base method)
|
|
||||||
|
|
||||||
return sampling_params
|
return sampling_params
|
||||||
|
|
||||||
async def _handle_streaming_request(
|
async def _handle_streaming_request(
|
||||||
@@ -196,122 +116,126 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
raw_request: Request,
|
raw_request: Request,
|
||||||
) -> StreamingResponse:
|
) -> StreamingResponse:
|
||||||
"""Handle streaming completion request"""
|
"""Handle streaming completion request"""
|
||||||
created = int(time.time())
|
|
||||||
|
|
||||||
async def generate_stream_resp():
|
|
||||||
stream_buffers = {}
|
|
||||||
n_prev_tokens = {}
|
|
||||||
prompt_tokens = {}
|
|
||||||
completion_tokens = {}
|
|
||||||
cached_tokens = {}
|
|
||||||
|
|
||||||
try:
|
|
||||||
async for content in self.tokenizer_manager.generate_request(
|
|
||||||
adapted_request, raw_request
|
|
||||||
):
|
|
||||||
index = content.get("index", 0)
|
|
||||||
|
|
||||||
stream_buffer = stream_buffers.get(index, "")
|
|
||||||
n_prev_token = n_prev_tokens.get(index, 0)
|
|
||||||
|
|
||||||
text = content["text"]
|
|
||||||
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
|
|
||||||
completion_tokens[index] = content["meta_info"]["completion_tokens"]
|
|
||||||
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
|
|
||||||
|
|
||||||
# Handle echo for first chunk
|
|
||||||
if not stream_buffer: # The first chunk
|
|
||||||
if request.echo:
|
|
||||||
echo_text = self._get_echo_text(request, index)
|
|
||||||
text = echo_text + text
|
|
||||||
|
|
||||||
# Handle logprobs
|
|
||||||
logprobs = None
|
|
||||||
if request.logprobs is not None:
|
|
||||||
# The first chunk and echo is enabled.
|
|
||||||
if not stream_buffer and request.echo:
|
|
||||||
input_token_logprobs = content["meta_info"][
|
|
||||||
"input_token_logprobs"
|
|
||||||
]
|
|
||||||
input_top_logprobs = content["meta_info"][
|
|
||||||
"input_top_logprobs"
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
input_token_logprobs = None
|
|
||||||
input_top_logprobs = None
|
|
||||||
|
|
||||||
logprobs = to_openai_style_logprobs(
|
|
||||||
input_token_logprobs=input_token_logprobs,
|
|
||||||
input_top_logprobs=input_top_logprobs,
|
|
||||||
output_token_logprobs=content["meta_info"][
|
|
||||||
"output_token_logprobs"
|
|
||||||
][n_prev_token:],
|
|
||||||
output_top_logprobs=content["meta_info"][
|
|
||||||
"output_top_logprobs"
|
|
||||||
][n_prev_token:],
|
|
||||||
)
|
|
||||||
n_prev_token = len(
|
|
||||||
content["meta_info"]["output_token_logprobs"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate delta
|
|
||||||
delta = text[len(stream_buffer) :]
|
|
||||||
stream_buffer = stream_buffer + delta
|
|
||||||
finish_reason = content["meta_info"]["finish_reason"]
|
|
||||||
|
|
||||||
choice_data = CompletionResponseStreamChoice(
|
|
||||||
index=index,
|
|
||||||
text=delta,
|
|
||||||
logprobs=logprobs,
|
|
||||||
finish_reason=finish_reason["type"] if finish_reason else None,
|
|
||||||
matched_stop=(
|
|
||||||
finish_reason["matched"]
|
|
||||||
if finish_reason and "matched" in finish_reason
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
|
||||||
chunk = CompletionStreamResponse(
|
|
||||||
id=content["meta_info"]["id"],
|
|
||||||
created=created,
|
|
||||||
object="text_completion",
|
|
||||||
choices=[choice_data],
|
|
||||||
model=request.model,
|
|
||||||
)
|
|
||||||
|
|
||||||
stream_buffers[index] = stream_buffer
|
|
||||||
n_prev_tokens[index] = n_prev_token
|
|
||||||
|
|
||||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
|
||||||
|
|
||||||
# Handle final usage chunk
|
|
||||||
if request.stream_options and request.stream_options.include_usage:
|
|
||||||
usage = self._calculate_streaming_usage_base(
|
|
||||||
prompt_tokens, completion_tokens, cached_tokens, request.n
|
|
||||||
)
|
|
||||||
final_usage_chunk = CompletionStreamResponse(
|
|
||||||
id=content["meta_info"]["id"],
|
|
||||||
created=created,
|
|
||||||
choices=[],
|
|
||||||
model=request.model,
|
|
||||||
usage=usage,
|
|
||||||
)
|
|
||||||
final_usage_data = final_usage_chunk.model_dump_json(
|
|
||||||
exclude_none=True
|
|
||||||
)
|
|
||||||
yield f"data: {final_usage_data}\n\n"
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error = self.create_streaming_error_response(str(e))
|
|
||||||
yield f"data: {error}\n\n"
|
|
||||||
|
|
||||||
yield "data: [DONE]\n\n"
|
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
generate_stream_resp(),
|
self._generate_completion_stream(adapted_request, request, raw_request),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
background=self.tokenizer_manager.create_abort_task(adapted_request),
|
background=self.tokenizer_manager.create_abort_task(adapted_request),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _generate_completion_stream(
|
||||||
|
self,
|
||||||
|
adapted_request: GenerateReqInput,
|
||||||
|
request: CompletionRequest,
|
||||||
|
raw_request: Request,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Generate streaming completion response"""
|
||||||
|
created = int(time.time())
|
||||||
|
|
||||||
|
# State tracking for streaming
|
||||||
|
stream_buffers = {}
|
||||||
|
n_prev_tokens = {}
|
||||||
|
|
||||||
|
# Usage tracking
|
||||||
|
prompt_tokens = {}
|
||||||
|
completion_tokens = {}
|
||||||
|
cached_tokens = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for content in self.tokenizer_manager.generate_request(
|
||||||
|
adapted_request, raw_request
|
||||||
|
):
|
||||||
|
index = content.get("index", 0)
|
||||||
|
|
||||||
|
text = content["text"]
|
||||||
|
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
|
||||||
|
completion_tokens[index] = content["meta_info"]["completion_tokens"]
|
||||||
|
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
|
||||||
|
|
||||||
|
stream_buffer = stream_buffers.get(index, "")
|
||||||
|
# Handle echo for first chunk
|
||||||
|
if not stream_buffer: # The first chunk
|
||||||
|
if request.echo:
|
||||||
|
echo_text = self._get_echo_text(request, index)
|
||||||
|
text = echo_text + text
|
||||||
|
|
||||||
|
# Handle logprobs
|
||||||
|
logprobs = None
|
||||||
|
if request.logprobs is not None:
|
||||||
|
# The first chunk and echo is enabled.
|
||||||
|
if not stream_buffer and request.echo:
|
||||||
|
input_token_logprobs = content["meta_info"][
|
||||||
|
"input_token_logprobs"
|
||||||
|
]
|
||||||
|
input_top_logprobs = content["meta_info"]["input_top_logprobs"]
|
||||||
|
else:
|
||||||
|
input_token_logprobs = None
|
||||||
|
input_top_logprobs = None
|
||||||
|
|
||||||
|
n_prev_token = n_prev_tokens.get(index, 0)
|
||||||
|
logprobs = to_openai_style_logprobs(
|
||||||
|
input_token_logprobs=input_token_logprobs,
|
||||||
|
input_top_logprobs=input_top_logprobs,
|
||||||
|
output_token_logprobs=content["meta_info"][
|
||||||
|
"output_token_logprobs"
|
||||||
|
][n_prev_token:],
|
||||||
|
output_top_logprobs=content["meta_info"]["output_top_logprobs"][
|
||||||
|
n_prev_token:
|
||||||
|
],
|
||||||
|
)
|
||||||
|
n_prev_tokens[index] = len(
|
||||||
|
content["meta_info"]["output_token_logprobs"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate delta
|
||||||
|
delta = text[len(stream_buffer) :]
|
||||||
|
stream_buffers[index] = stream_buffer + delta
|
||||||
|
finish_reason = content["meta_info"]["finish_reason"]
|
||||||
|
|
||||||
|
choice_data = CompletionResponseStreamChoice(
|
||||||
|
index=index,
|
||||||
|
text=delta,
|
||||||
|
logprobs=logprobs,
|
||||||
|
finish_reason=finish_reason["type"] if finish_reason else None,
|
||||||
|
matched_stop=(
|
||||||
|
finish_reason["matched"]
|
||||||
|
if finish_reason and "matched" in finish_reason
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
chunk = CompletionStreamResponse(
|
||||||
|
id=content["meta_info"]["id"],
|
||||||
|
created=created,
|
||||||
|
object="text_completion",
|
||||||
|
choices=[choice_data],
|
||||||
|
model=request.model,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||||
|
|
||||||
|
# Handle final usage chunk
|
||||||
|
if request.stream_options and request.stream_options.include_usage:
|
||||||
|
usage = self._calculate_streaming_usage_base(
|
||||||
|
prompt_tokens,
|
||||||
|
completion_tokens,
|
||||||
|
cached_tokens,
|
||||||
|
request.n,
|
||||||
|
)
|
||||||
|
final_usage_chunk = CompletionStreamResponse(
|
||||||
|
id=content["meta_info"]["id"],
|
||||||
|
created=created,
|
||||||
|
choices=[],
|
||||||
|
model=request.model,
|
||||||
|
usage=usage,
|
||||||
|
)
|
||||||
|
final_usage_data = final_usage_chunk.model_dump_json(exclude_none=True)
|
||||||
|
yield f"data: {final_usage_data}\n\n"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error = self.create_streaming_error_response(str(e))
|
||||||
|
yield f"data: {error}\n\n"
|
||||||
|
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
async def _handle_non_streaming_request(
|
async def _handle_non_streaming_request(
|
||||||
self,
|
self,
|
||||||
adapted_request: GenerateReqInput,
|
adapted_request: GenerateReqInput,
|
||||||
@@ -334,7 +258,6 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
request,
|
request,
|
||||||
ret,
|
ret,
|
||||||
int(time.time()),
|
int(time.time()),
|
||||||
cache_report=self.tokenizer_manager.server_args.enable_cache_report,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
@@ -344,7 +267,6 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
request: CompletionRequest,
|
request: CompletionRequest,
|
||||||
ret: List[Dict[str, Any]],
|
ret: List[Dict[str, Any]],
|
||||||
created: int,
|
created: int,
|
||||||
cache_report: bool = False,
|
|
||||||
) -> CompletionResponse:
|
) -> CompletionResponse:
|
||||||
"""Build completion response from generation results"""
|
"""Build completion response from generation results"""
|
||||||
choices = []
|
choices = []
|
||||||
@@ -352,7 +274,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
|
|
||||||
# Prepare echo prompts if needed
|
# Prepare echo prompts if needed
|
||||||
echo_prompts = []
|
echo_prompts = []
|
||||||
if (not isinstance(request, list)) and request.echo:
|
if request.echo:
|
||||||
echo_prompts = self._prepare_echo_prompts(request)
|
echo_prompts = self._prepare_echo_prompts(request)
|
||||||
echo = True
|
echo = True
|
||||||
|
|
||||||
@@ -360,21 +282,13 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
text = ret_item["text"]
|
text = ret_item["text"]
|
||||||
|
|
||||||
# Handle echo
|
# Handle echo
|
||||||
if isinstance(request, list) and request[idx].echo:
|
if echo:
|
||||||
echo = True
|
|
||||||
text = request[idx].prompt + text
|
|
||||||
elif echo and not isinstance(request, list):
|
|
||||||
prompt_index = idx // request.n
|
prompt_index = idx // request.n
|
||||||
text = echo_prompts[prompt_index] + text
|
text = echo_prompts[prompt_index] + text
|
||||||
|
|
||||||
# Handle logprobs
|
# Handle logprobs
|
||||||
logprobs = None
|
logprobs = None
|
||||||
if isinstance(request, list) and request[idx].logprobs is not None:
|
if request.logprobs is not None:
|
||||||
logprobs = True
|
|
||||||
elif (not isinstance(request, list)) and request.logprobs is not None:
|
|
||||||
logprobs = True
|
|
||||||
|
|
||||||
if logprobs:
|
|
||||||
if echo:
|
if echo:
|
||||||
input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
|
input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
|
||||||
input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
|
input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
|
||||||
@@ -407,6 +321,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
choices.append(choice_data)
|
choices.append(choice_data)
|
||||||
|
|
||||||
# Calculate usage
|
# Calculate usage
|
||||||
|
cache_report = self.tokenizer_manager.server_args.enable_cache_report
|
||||||
usage = aggregate_token_usage(ret, request.n, cache_report)
|
usage = aggregate_token_usage(ret, request.n, cache_report)
|
||||||
|
|
||||||
return CompletionResponse(
|
return CompletionResponse(
|
||||||
|
|||||||
@@ -54,35 +54,25 @@ class OpenAIServingEmbedding(OpenAIServingBase):
|
|||||||
return f"All items in input list must be integers"
|
return f"All items in input list must be integers"
|
||||||
if item < 0:
|
if item < 0:
|
||||||
return f"Token ID at index {i} must be non-negative"
|
return f"Token ID at index {i} must be non-negative"
|
||||||
elif isinstance(first_item, list):
|
|
||||||
# List of lists (multiple token sequences)
|
|
||||||
for i, item in enumerate(input):
|
|
||||||
if not isinstance(item, list):
|
|
||||||
return f"Input at index {i} must be a list"
|
|
||||||
if not item:
|
|
||||||
return f"Input at index {i} cannot be empty"
|
|
||||||
if not all(isinstance(token, int) for token in item):
|
|
||||||
return f"Input at index {i} must contain only integers"
|
|
||||||
if any(token < 0 for token in item):
|
|
||||||
return f"Input at index {i} contains negative token IDs"
|
|
||||||
# Note: MultimodalEmbeddingInput validation would be handled by Pydantic
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _convert_to_internal_request(
|
def _convert_to_internal_request(
|
||||||
self,
|
self,
|
||||||
request: EmbeddingRequest,
|
request: EmbeddingRequest,
|
||||||
request_id: str,
|
) -> tuple[EmbeddingReqInput, EmbeddingRequest]:
|
||||||
) -> tuple[EmbeddingReqInput, Union[EmbeddingRequest, List[EmbeddingRequest]]]:
|
|
||||||
"""Convert OpenAI embedding request to internal format"""
|
"""Convert OpenAI embedding request to internal format"""
|
||||||
prompt = request.input
|
prompt = request.input
|
||||||
|
|
||||||
if isinstance(prompt, str):
|
if isinstance(prompt, str):
|
||||||
# Single string input
|
# Single string input
|
||||||
prompt_kwargs = {"text": prompt}
|
prompt_kwargs = {"text": prompt}
|
||||||
elif isinstance(prompt, list):
|
elif isinstance(prompt, list):
|
||||||
if len(prompt) > 0 and isinstance(prompt[0], str):
|
if len(prompt) > 0 and isinstance(prompt[0], str):
|
||||||
# List of strings
|
# List of strings - if it's a single string in a list, treat as single string
|
||||||
prompt_kwargs = {"text": prompt}
|
if len(prompt) == 1:
|
||||||
|
prompt_kwargs = {"text": prompt[0]}
|
||||||
|
else:
|
||||||
|
prompt_kwargs = {"text": prompt}
|
||||||
elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
|
elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
|
||||||
# Handle multimodal embedding inputs
|
# Handle multimodal embedding inputs
|
||||||
texts = []
|
texts = []
|
||||||
@@ -94,7 +84,6 @@ class OpenAIServingEmbedding(OpenAIServingBase):
|
|||||||
|
|
||||||
generate_prompts = []
|
generate_prompts = []
|
||||||
# Check if we have a chat template for multimodal embeddings
|
# Check if we have a chat template for multimodal embeddings
|
||||||
# This would need to be passed in from the server configuration
|
|
||||||
chat_template_name = getattr(
|
chat_template_name = getattr(
|
||||||
self.tokenizer_manager, "chat_template_name", None
|
self.tokenizer_manager, "chat_template_name", None
|
||||||
)
|
)
|
||||||
@@ -121,6 +110,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
|
|||||||
else:
|
else:
|
||||||
# Other types (should not happen but handle gracefully)
|
# Other types (should not happen but handle gracefully)
|
||||||
prompt_kwargs = {"input_ids": prompt}
|
prompt_kwargs = {"input_ids": prompt}
|
||||||
|
|
||||||
adapted_request = EmbeddingReqInput(
|
adapted_request = EmbeddingReqInput(
|
||||||
**prompt_kwargs,
|
**prompt_kwargs,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -104,52 +104,50 @@ class ServingChatTestCase(unittest.TestCase):
|
|||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
adapted, processed = self.chat._convert_to_internal_request(
|
adapted, processed = self.chat._convert_to_internal_request(self.basic_req)
|
||||||
[self.basic_req], ["rid"]
|
|
||||||
)
|
|
||||||
self.assertIsInstance(adapted, GenerateReqInput)
|
self.assertIsInstance(adapted, GenerateReqInput)
|
||||||
self.assertFalse(adapted.stream)
|
self.assertFalse(adapted.stream)
|
||||||
self.assertEqual(processed, self.basic_req)
|
self.assertEqual(processed, self.basic_req)
|
||||||
|
|
||||||
# ------------- tool-call branch -------------
|
# # ------------- tool-call branch -------------
|
||||||
def test_tool_call_request_conversion(self):
|
# def test_tool_call_request_conversion(self):
|
||||||
req = ChatCompletionRequest(
|
# req = ChatCompletionRequest(
|
||||||
model="x",
|
# model="x",
|
||||||
messages=[{"role": "user", "content": "Weather?"}],
|
# messages=[{"role": "user", "content": "Weather?"}],
|
||||||
tools=[
|
# tools=[
|
||||||
{
|
# {
|
||||||
"type": "function",
|
# "type": "function",
|
||||||
"function": {
|
# "function": {
|
||||||
"name": "get_weather",
|
# "name": "get_weather",
|
||||||
"parameters": {"type": "object", "properties": {}},
|
# "parameters": {"type": "object", "properties": {}},
|
||||||
},
|
# },
|
||||||
}
|
# }
|
||||||
],
|
# ],
|
||||||
tool_choice="auto",
|
# tool_choice="auto",
|
||||||
)
|
# )
|
||||||
|
|
||||||
with patch.object(
|
# with patch.object(
|
||||||
self.chat,
|
# self.chat,
|
||||||
"_process_messages",
|
# "_process_messages",
|
||||||
return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
|
# return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
|
||||||
):
|
# ):
|
||||||
adapted, _ = self.chat._convert_to_internal_request([req], ["rid"])
|
# adapted, _ = self.chat._convert_to_internal_request(req, "rid")
|
||||||
self.assertEqual(adapted.rid, "rid")
|
# self.assertEqual(adapted.rid, "rid")
|
||||||
|
|
||||||
def test_tool_choice_none(self):
|
# def test_tool_choice_none(self):
|
||||||
req = ChatCompletionRequest(
|
# req = ChatCompletionRequest(
|
||||||
model="x",
|
# model="x",
|
||||||
messages=[{"role": "user", "content": "Hi"}],
|
# messages=[{"role": "user", "content": "Hi"}],
|
||||||
tools=[{"type": "function", "function": {"name": "noop"}}],
|
# tools=[{"type": "function", "function": {"name": "noop"}}],
|
||||||
tool_choice="none",
|
# tool_choice="none",
|
||||||
)
|
# )
|
||||||
with patch.object(
|
# with patch.object(
|
||||||
self.chat,
|
# self.chat,
|
||||||
"_process_messages",
|
# "_process_messages",
|
||||||
return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
|
# return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
|
||||||
):
|
# ):
|
||||||
adapted, _ = self.chat._convert_to_internal_request([req], ["rid"])
|
# adapted, _ = self.chat._convert_to_internal_request(req, "rid")
|
||||||
self.assertEqual(adapted.rid, "rid")
|
# self.assertEqual(adapted.rid, "rid")
|
||||||
|
|
||||||
# ------------- multimodal branch -------------
|
# ------------- multimodal branch -------------
|
||||||
def test_multimodal_request_with_images(self):
|
def test_multimodal_request_with_images(self):
|
||||||
|
|||||||
@@ -36,12 +36,12 @@ class ServingCompletionTestCase(unittest.TestCase):
|
|||||||
# ---------- prompt-handling ----------
|
# ---------- prompt-handling ----------
|
||||||
def test_single_string_prompt(self):
|
def test_single_string_prompt(self):
|
||||||
req = CompletionRequest(model="x", prompt="Hello world", max_tokens=100)
|
req = CompletionRequest(model="x", prompt="Hello world", max_tokens=100)
|
||||||
internal, _ = self.sc._convert_to_internal_request([req], ["id"])
|
internal, _ = self.sc._convert_to_internal_request(req)
|
||||||
self.assertEqual(internal.text, "Hello world")
|
self.assertEqual(internal.text, "Hello world")
|
||||||
|
|
||||||
def test_single_token_ids_prompt(self):
|
def test_single_token_ids_prompt(self):
|
||||||
req = CompletionRequest(model="x", prompt=[1, 2, 3, 4], max_tokens=100)
|
req = CompletionRequest(model="x", prompt=[1, 2, 3, 4], max_tokens=100)
|
||||||
internal, _ = self.sc._convert_to_internal_request([req], ["id"])
|
internal, _ = self.sc._convert_to_internal_request(req)
|
||||||
self.assertEqual(internal.input_ids, [1, 2, 3, 4])
|
self.assertEqual(internal.input_ids, [1, 2, 3, 4])
|
||||||
|
|
||||||
def test_completion_template_handling(self):
|
def test_completion_template_handling(self):
|
||||||
@@ -55,7 +55,7 @@ class ServingCompletionTestCase(unittest.TestCase):
|
|||||||
"sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
|
"sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
|
||||||
return_value="processed_prompt",
|
return_value="processed_prompt",
|
||||||
):
|
):
|
||||||
internal, _ = self.sc._convert_to_internal_request([req], ["id"])
|
internal, _ = self.sc._convert_to_internal_request(req)
|
||||||
self.assertEqual(internal.text, "processed_prompt")
|
self.assertEqual(internal.text, "processed_prompt")
|
||||||
|
|
||||||
# ---------- echo-handling ----------
|
# ---------- echo-handling ----------
|
||||||
|
|||||||
@@ -94,50 +94,42 @@ class ServingEmbeddingTestCase(unittest.TestCase):
|
|||||||
def test_convert_single_string_request(self):
|
def test_convert_single_string_request(self):
|
||||||
"""Test converting single string request to internal format."""
|
"""Test converting single string request to internal format."""
|
||||||
adapted_request, processed_request = (
|
adapted_request, processed_request = (
|
||||||
self.serving_embedding._convert_to_internal_request(
|
self.serving_embedding._convert_to_internal_request(self.basic_req)
|
||||||
self.basic_req, "test-id"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||||
self.assertEqual(adapted_request.text, "Hello, how are you?")
|
self.assertEqual(adapted_request.text, "Hello, how are you?")
|
||||||
self.assertEqual(adapted_request.rid, None)
|
# self.assertEqual(adapted_request.rid, "test-id")
|
||||||
self.assertEqual(processed_request, self.basic_req)
|
self.assertEqual(processed_request, self.basic_req)
|
||||||
|
|
||||||
def test_convert_list_string_request(self):
|
def test_convert_list_string_request(self):
|
||||||
"""Test converting list of strings request to internal format."""
|
"""Test converting list of strings request to internal format."""
|
||||||
adapted_request, processed_request = (
|
adapted_request, processed_request = (
|
||||||
self.serving_embedding._convert_to_internal_request(
|
self.serving_embedding._convert_to_internal_request(self.list_req)
|
||||||
self.list_req, "test-id"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
|
adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
|
||||||
)
|
)
|
||||||
self.assertEqual(adapted_request.rid, None)
|
# self.assertEqual(adapted_request.rid, "test-id")
|
||||||
self.assertEqual(processed_request, self.list_req)
|
self.assertEqual(processed_request, self.list_req)
|
||||||
|
|
||||||
def test_convert_token_ids_request(self):
|
def test_convert_token_ids_request(self):
|
||||||
"""Test converting token IDs request to internal format."""
|
"""Test converting token IDs request to internal format."""
|
||||||
adapted_request, processed_request = (
|
adapted_request, processed_request = (
|
||||||
self.serving_embedding._convert_to_internal_request(
|
self.serving_embedding._convert_to_internal_request(self.token_ids_req)
|
||||||
self.token_ids_req, "test-id"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||||
self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
|
self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
|
||||||
self.assertEqual(adapted_request.rid, None)
|
# self.assertEqual(adapted_request.rid, "test-id")
|
||||||
self.assertEqual(processed_request, self.token_ids_req)
|
self.assertEqual(processed_request, self.token_ids_req)
|
||||||
|
|
||||||
def test_convert_multimodal_request(self):
|
def test_convert_multimodal_request(self):
|
||||||
"""Test converting multimodal request to internal format."""
|
"""Test converting multimodal request to internal format."""
|
||||||
adapted_request, processed_request = (
|
adapted_request, processed_request = (
|
||||||
self.serving_embedding._convert_to_internal_request(
|
self.serving_embedding._convert_to_internal_request(self.multimodal_req)
|
||||||
self.multimodal_req, "test-id"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||||
@@ -147,7 +139,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
|
|||||||
self.assertIn("World", adapted_request.text)
|
self.assertIn("World", adapted_request.text)
|
||||||
self.assertEqual(adapted_request.image_data[0], "base64_image_data")
|
self.assertEqual(adapted_request.image_data[0], "base64_image_data")
|
||||||
self.assertIsNone(adapted_request.image_data[1])
|
self.assertIsNone(adapted_request.image_data[1])
|
||||||
self.assertEqual(adapted_request.rid, None)
|
# self.assertEqual(adapted_request.rid, "test-id")
|
||||||
|
|
||||||
def test_build_single_embedding_response(self):
|
def test_build_single_embedding_response(self):
|
||||||
"""Test building response for single embedding."""
|
"""Test building response for single embedding."""
|
||||||
@@ -194,72 +186,86 @@ class ServingEmbeddingTestCase(unittest.TestCase):
|
|||||||
self.assertEqual(response.usage.prompt_tokens, 7) # 3 + 4
|
self.assertEqual(response.usage.prompt_tokens, 7) # 3 + 4
|
||||||
self.assertEqual(response.usage.total_tokens, 7)
|
self.assertEqual(response.usage.total_tokens, 7)
|
||||||
|
|
||||||
async def test_handle_request_success(self):
|
def test_handle_request_success(self):
|
||||||
"""Test successful embedding request handling."""
|
"""Test successful embedding request handling."""
|
||||||
|
|
||||||
# Mock the generate_request to return expected data
|
async def run_test():
|
||||||
async def mock_generate():
|
# Mock the generate_request to return expected data
|
||||||
yield {
|
async def mock_generate():
|
||||||
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
|
yield {
|
||||||
"meta_info": {"prompt_tokens": 5},
|
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||||
}
|
"meta_info": {"prompt_tokens": 5},
|
||||||
|
}
|
||||||
|
|
||||||
self.serving_embedding.tokenizer_manager.generate_request = Mock(
|
self.serving_embedding.tokenizer_manager.generate_request = Mock(
|
||||||
return_value=mock_generate()
|
return_value=mock_generate()
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await self.serving_embedding.handle_request(
|
response = await self.serving_embedding.handle_request(
|
||||||
self.basic_req, self.request
|
self.basic_req, self.request
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsInstance(response, EmbeddingResponse)
|
self.assertIsInstance(response, EmbeddingResponse)
|
||||||
self.assertEqual(len(response.data), 1)
|
self.assertEqual(len(response.data), 1)
|
||||||
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
|
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||||||
|
|
||||||
async def test_handle_request_validation_error(self):
|
asyncio.run(run_test())
|
||||||
|
|
||||||
|
def test_handle_request_validation_error(self):
|
||||||
"""Test handling request with validation error."""
|
"""Test handling request with validation error."""
|
||||||
invalid_request = EmbeddingRequest(model="test-model", input="")
|
|
||||||
|
|
||||||
response = await self.serving_embedding.handle_request(
|
async def run_test():
|
||||||
invalid_request, self.request
|
invalid_request = EmbeddingRequest(model="test-model", input="")
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsInstance(response, ORJSONResponse)
|
response = await self.serving_embedding.handle_request(
|
||||||
self.assertEqual(response.status_code, 400)
|
invalid_request, self.request
|
||||||
|
)
|
||||||
|
|
||||||
async def test_handle_request_generation_error(self):
|
self.assertIsInstance(response, ORJSONResponse)
|
||||||
|
self.assertEqual(response.status_code, 400)
|
||||||
|
|
||||||
|
asyncio.run(run_test())
|
||||||
|
|
||||||
|
def test_handle_request_generation_error(self):
|
||||||
"""Test handling request with generation error."""
|
"""Test handling request with generation error."""
|
||||||
|
|
||||||
# Mock generate_request to raise an error
|
async def run_test():
|
||||||
async def mock_generate_error():
|
# Mock generate_request to raise an error
|
||||||
raise ValueError("Generation failed")
|
async def mock_generate_error():
|
||||||
yield # This won't be reached but needed for async generator
|
raise ValueError("Generation failed")
|
||||||
|
yield # This won't be reached but needed for async generator
|
||||||
|
|
||||||
self.serving_embedding.tokenizer_manager.generate_request = Mock(
|
self.serving_embedding.tokenizer_manager.generate_request = Mock(
|
||||||
return_value=mock_generate_error()
|
return_value=mock_generate_error()
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await self.serving_embedding.handle_request(
|
|
||||||
self.basic_req, self.request
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsInstance(response, ORJSONResponse)
|
|
||||||
self.assertEqual(response.status_code, 400)
|
|
||||||
|
|
||||||
async def test_handle_request_internal_error(self):
|
|
||||||
"""Test handling request with internal server error."""
|
|
||||||
# Mock _convert_to_internal_request to raise an exception
|
|
||||||
with patch.object(
|
|
||||||
self.serving_embedding,
|
|
||||||
"_convert_to_internal_request",
|
|
||||||
side_effect=Exception("Internal error"),
|
|
||||||
):
|
|
||||||
response = await self.serving_embedding.handle_request(
|
response = await self.serving_embedding.handle_request(
|
||||||
self.basic_req, self.request
|
self.basic_req, self.request
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsInstance(response, ORJSONResponse)
|
self.assertIsInstance(response, ORJSONResponse)
|
||||||
self.assertEqual(response.status_code, 500)
|
self.assertEqual(response.status_code, 400)
|
||||||
|
|
||||||
|
asyncio.run(run_test())
|
||||||
|
|
||||||
|
def test_handle_request_internal_error(self):
|
||||||
|
"""Test handling request with internal server error."""
|
||||||
|
|
||||||
|
async def run_test():
|
||||||
|
# Mock _convert_to_internal_request to raise an exception
|
||||||
|
with patch.object(
|
||||||
|
self.serving_embedding,
|
||||||
|
"_convert_to_internal_request",
|
||||||
|
side_effect=Exception("Internal error"),
|
||||||
|
):
|
||||||
|
response = await self.serving_embedding.handle_request(
|
||||||
|
self.basic_req, self.request
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIsInstance(response, ORJSONResponse)
|
||||||
|
self.assertEqual(response.status_code, 500)
|
||||||
|
|
||||||
|
asyncio.run(run_test())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user