Abort disconnected requests (#457)

Lianmin Zheng
2024-05-20 18:41:21 -07:00
committed by GitHub
parent 3e684be7a3
commit 8dbdc018a3
8 changed files with 202 additions and 132 deletions
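Only one of the eight changed files is rendered below: the OpenAI API adapter. Its side of the change is mechanical: raw_request is threaded into tokenizer_manager.generate_request, and streaming responses get an abort background task. The disconnect detection itself lives in files not shown on this page; a minimal sketch of the Starlette pattern it presumably relies on (watch_disconnect and the abort callback are hypothetical names, not this commit's code):

# Hypothetical sketch of the disconnect-watching pattern enabled by passing
# raw_request down to the tokenizer manager. Not code from this commit.
import asyncio

from fastapi import Request


async def watch_disconnect(request: Request, abort) -> None:
    # Poll Starlette's disconnect flag while a generation is in flight.
    while not await request.is_disconnected():
        await asyncio.sleep(0.5)
    abort()  # caller-supplied cleanup, e.g. aborting the scheduled request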


@@ -1,10 +1,12 @@
 """Conversion between OpenAI APIs and native SRT APIs"""
 
+import asyncio
 import json
 import os
+from http import HTTPStatus
 
-from fastapi import HTTPException, Request
-from fastapi.responses import StreamingResponse
+from fastapi import Request
+from fastapi.responses import StreamingResponse, JSONResponse
 
 from sglang.srt.conversation import (
     Conversation,
@@ -27,14 +29,36 @@ from sglang.srt.openai_protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    ErrorResponse,
     LogProbs,
     UsageInfo,
 )
-from sglang.srt.utils import jsonify_pydantic_model
 
 chat_template_name = None
 
 
+def create_error_response(
+    message: str,
+    err_type: str = "BadRequestError",
+    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+):
+    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
+    return JSONResponse(content=error.model_dump(), status_code=error.code)
+
+
+def create_streaming_error_response(
+    message: str,
+    err_type: str = "BadRequestError",
+    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+) -> str:
+    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
+    json_str = json.dumps({"error": error.model_dump()})
+    return json_str
+
+
 def load_chat_template_for_openai_api(chat_template_arg):
     global chat_template_name
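The two helpers split error reporting by transport: create_error_response returns a finished JSONResponse for non-streaming calls, while create_streaming_error_response returns a JSON string for the caller to wrap in an SSE frame. A stand-alone illustration of the wire format, built with plain json.dumps (the real ErrorResponse may carry additional fields):

import json
from http import HTTPStatus

# Mirrors the fields set by the helpers above; ErrorResponse may add more.
error = {
    "message": "prompt too long",
    "type": "BadRequestError",
    "code": HTTPStatus.BAD_REQUEST.value,
}
sse_frame = f"data: {json.dumps({'error': error})}\n\n"
print(sse_frame)  # sent into the stream before "data: [DONE]\n\n"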
@@ -74,8 +98,8 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
     request = CompletionRequest(**request_json)
 
-    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
-    assert request.n == 1
+    if request.n != 1:
+        return create_error_response("n != 1 is not supported")
 
     adapted_request = GenerateReqInput(
         text=request.prompt,
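With the assert gone, an unsupported n comes back as a structured 400 instead of a 500 from an AssertionError. A hypothetical smoke test against a local server (port and model name are assumptions):

import httpx

resp = httpx.post(
    "http://localhost:30000/v1/completions",  # assumed local SRT endpoint
    json={"model": "default", "prompt": "hi", "n": 2},
)
assert resp.status_code == 400
assert resp.json()["message"] == "n != 1 is not supported"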
@@ -93,79 +117,88 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
         return_text_in_logprobs=True,
         stream=request.stream,
     )
-    adapted_request.post_init()
 
     if adapted_request.stream:
 
         async def generate_stream_resp():
             stream_buffer = ""
             n_prev_token = 0
-            async for content in tokenizer_manager.generate_request(adapted_request):
-                text = content["text"]
-                prompt_tokens = content["meta_info"]["prompt_tokens"]
-                completion_tokens = content["meta_info"]["completion_tokens"]
-
-                if not stream_buffer:  # The first chunk
-                    if request.echo:
-                        # Prepend prompt in response text.
-                        text = request.prompt + text
-
-                if request.logprobs:
-                    # The first chunk and echo is enabled.
-                    if not stream_buffer and request.echo:
-                        prefill_token_logprobs = content["meta_info"][
-                            "prefill_token_logprobs"
-                        ]
-                        prefill_top_logprobs = content["meta_info"][
-                            "prefill_top_logprobs"
-                        ]
-                    else:
-                        prefill_token_logprobs = None
-                        prefill_top_logprobs = None
-
-                    logprobs = to_openai_style_logprobs(
-                        prefill_token_logprobs=prefill_token_logprobs,
-                        prefill_top_logprobs=prefill_top_logprobs,
-                        decode_token_logprobs=content["meta_info"][
-                            "decode_token_logprobs"
-                        ][n_prev_token:],
-                        decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
-                            n_prev_token:
-                        ],
-                    )
-                    n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
-                else:
-                    logprobs = None
-
-                delta = text[len(stream_buffer) :]
-                stream_buffer = content["text"]
-                choice_data = CompletionResponseStreamChoice(
-                    index=0,
-                    text=delta,
-                    logprobs=logprobs,
-                    finish_reason=content["meta_info"]["finish_reason"],
-                )
-                chunk = CompletionStreamResponse(
-                    id=content["meta_info"]["id"],
-                    object="text_completion",
-                    choices=[choice_data],
-                    model=request.model,
-                    usage=UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
-                    ),
-                )
-                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
+            try:
+                async for content in tokenizer_manager.generate_request(
+                    adapted_request, raw_request
+                ):
+                    text = content["text"]
+                    prompt_tokens = content["meta_info"]["prompt_tokens"]
+                    completion_tokens = content["meta_info"]["completion_tokens"]
+
+                    if not stream_buffer:  # The first chunk
+                        if request.echo:
+                            # Prepend prompt in response text.
+                            text = request.prompt + text
+
+                    if request.logprobs:
+                        # The first chunk and echo is enabled.
+                        if not stream_buffer and request.echo:
+                            prefill_token_logprobs = content["meta_info"][
+                                "prefill_token_logprobs"
+                            ]
+                            prefill_top_logprobs = content["meta_info"][
+                                "prefill_top_logprobs"
+                            ]
+                        else:
+                            prefill_token_logprobs = None
+                            prefill_top_logprobs = None
+
+                        logprobs = to_openai_style_logprobs(
+                            prefill_token_logprobs=prefill_token_logprobs,
+                            prefill_top_logprobs=prefill_top_logprobs,
+                            decode_token_logprobs=content["meta_info"][
+                                "decode_token_logprobs"
+                            ][n_prev_token:],
+                            decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
+                                n_prev_token:
+                            ],
+                        )
+                        n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
+                    else:
+                        logprobs = None
+
+                    delta = text[len(stream_buffer) :]
+                    stream_buffer = content["text"]
+                    choice_data = CompletionResponseStreamChoice(
+                        index=0,
+                        text=delta,
+                        logprobs=logprobs,
+                        finish_reason=content["meta_info"]["finish_reason"],
+                    )
+                    chunk = CompletionStreamResponse(
+                        id=content["meta_info"]["id"],
+                        object="text_completion",
+                        choices=[choice_data],
+                        model=request.model,
+                        usage=UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        ),
+                    )
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+            except ValueError as e:
+                error = create_streaming_error_response(str(e))
+                yield f"data: {error}\n\n"
             yield "data: [DONE]\n\n"
 
-        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
+        return StreamingResponse(
+            generate_stream_resp(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(adapted_request),
+        )
 
     # Non-streaming response.
-    ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
-    ret = ret[0] if isinstance(ret, list) else ret
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
+    ret = ret[0] if isinstance(ret, list) else ret
 
     prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]
     text = ret["text"]
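create_abort_task is defined on the tokenizer manager, in one of the changed files not rendered here. Starlette runs a StreamingResponse's background task once the response finishes, which includes the case where the client disconnects mid-stream. A hedged sketch of the idea (the class, its fields, and the grace period are assumptions, not this commit's code):

import asyncio

from starlette.background import BackgroundTask


class TokenizerManagerSketch:
    def __init__(self):
        self.rid_to_state = {}  # hypothetical registry of in-flight requests

    def abort_request(self, rid):
        # Hypothetical cleanup: drop the request so the scheduler stops decoding it.
        self.rid_to_state.pop(rid, None)

    def create_abort_task(self, obj):
        # Runs after the StreamingResponse ends, including on client disconnect.
        async def abort():
            await asyncio.sleep(3)  # assumed grace period for final chunks
            self.abort_request(obj.rid)

        return BackgroundTask(abort)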
@@ -212,8 +245,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
     request = ChatCompletionRequest(**request_json)
 
-    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
-    assert request.n == 1
+    if request.n != 1:
+        return create_error_response("n != 1 is not supported")
 
     # Prep the data needed for the underlying GenerateReqInput:
     #  - prompt: The full prompt string.
@@ -258,7 +291,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
         },
         stream=request.stream,
     )
-    adapted_request.post_init()
 
     if adapted_request.stream:
 
@@ -266,13 +298,29 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
             is_first = True
             stream_buffer = ""
 
-            async for content in tokenizer_manager.generate_request(adapted_request):
-                if is_first:
-                    # First chunk with role
-                    is_first = False
+            try:
+                async for content in tokenizer_manager.generate_request(
+                    adapted_request, raw_request
+                ):
+                    if is_first:
+                        # First chunk with role
+                        is_first = False
+                        choice_data = ChatCompletionResponseStreamChoice(
+                            index=0,
+                            delta=DeltaMessage(role="assistant"),
+                            finish_reason=content["meta_info"]["finish_reason"],
+                        )
+                        chunk = ChatCompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            choices=[choice_data],
+                            model=request.model,
+                        )
+                        yield f"data: {chunk.model_dump_json()}\n\n"
+
+                    text = content["text"]
+                    delta = text[len(stream_buffer) :]
+                    stream_buffer = text
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=0,
-                        delta=DeltaMessage(role="assistant"),
+                        delta=DeltaMessage(content=delta),
                         finish_reason=content["meta_info"]["finish_reason"],
                     )
                     chunk = ChatCompletionStreamResponse(
@@ -280,28 +328,22 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         choices=[choice_data],
                         model=request.model,
                     )
-                    yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
-
-                text = content["text"]
-                delta = text[len(stream_buffer) :]
-                stream_buffer = text
-                choice_data = ChatCompletionResponseStreamChoice(
-                    index=0,
-                    delta=DeltaMessage(content=delta),
-                    finish_reason=content["meta_info"]["finish_reason"],
-                )
-                chunk = ChatCompletionStreamResponse(
-                    id=content["meta_info"]["id"],
-                    choices=[choice_data],
-                    model=request.model,
-                )
-                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+            except ValueError as e:
+                error = create_streaming_error_response(str(e))
+                yield f"data: {error}\n\n"
             yield "data: [DONE]\n\n"
 
-        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
+        return StreamingResponse(
+            generate_stream_resp(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(adapted_request),
+        )
 
     # Non-streaming response.
-    ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
 
     prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]
     choice_data = ChatCompletionResponseChoice(
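Taken together, the change can be exercised by opening a streaming completion and dropping it early: before this commit the generation kept running toward max_tokens; after it, the abort task should reclaim the slot. A hypothetical client-side check (URL and payload are assumptions):

import asyncio

import httpx


async def main():
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST",
            "http://localhost:30000/v1/completions",  # assumed local SRT endpoint
            json={"model": "default", "prompt": "Hello", "max_tokens": 512, "stream": True},
        ) as resp:
            async for line in resp.aiter_lines():
                print(line)
                break  # disconnect early; the server should abort the request


asyncio.run(main())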