Abort disconnected requests (#457)
This commit is contained in:
@@ -580,8 +580,8 @@ class StreamExecutor:
|
|||||||
def _execute_role_end(self, expr: SglRoleEnd):
|
def _execute_role_end(self, expr: SglRoleEnd):
|
||||||
if (
|
if (
|
||||||
self.cur_role == "assistant"
|
self.cur_role == "assistant"
|
||||||
and self.backend.is_chat_model
|
|
||||||
and self.api_num_spec_tokens is not None
|
and self.api_num_spec_tokens is not None
|
||||||
|
and self.backend.is_chat_model
|
||||||
):
|
):
|
||||||
# Execute the stored lazy generation calls
|
# Execute the stored lazy generation calls
|
||||||
self.backend.role_end_generate(self)
|
self.backend.role_end_generate(self)
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ class FinishReason(IntEnum):
|
|||||||
EOS_TOKEN = auto()
|
EOS_TOKEN = auto()
|
||||||
LENGTH = auto()
|
LENGTH = auto()
|
||||||
STOP_STR = auto()
|
STOP_STR = auto()
|
||||||
|
ABORT = auto()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def to_str(reason):
|
def to_str(reason):
|
||||||
@@ -28,6 +29,8 @@ class FinishReason(IntEnum):
|
|||||||
return "length"
|
return "length"
|
||||||
elif reason == FinishReason.STOP_STR:
|
elif reason == FinishReason.STOP_STR:
|
||||||
return "stop"
|
return "stop"
|
||||||
|
elif reason == FinishReason.ABORT:
|
||||||
|
return "abort"
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -86,6 +89,35 @@ class Req:
|
|||||||
def max_new_tokens(self):
|
def max_new_tokens(self):
|
||||||
return self.sampling_params.max_new_tokens
|
return self.sampling_params.max_new_tokens
|
||||||
|
|
||||||
|
def check_finished(self):
|
||||||
|
if self.finished:
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
|
||||||
|
self.finished = True
|
||||||
|
self.finish_reason = FinishReason.LENGTH
|
||||||
|
return
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.output_ids[-1] == self.tokenizer.eos_token_id
|
||||||
|
and self.sampling_params.ignore_eos == False
|
||||||
|
):
|
||||||
|
self.finished = True
|
||||||
|
self.finish_reason = FinishReason.EOS_TOKEN
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(self.sampling_params.stop_strs) > 0:
|
||||||
|
tail_str = self.tokenizer.decode(
|
||||||
|
self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
|
||||||
|
)
|
||||||
|
|
||||||
|
for stop_str in self.sampling_params.stop_strs:
|
||||||
|
if stop_str in tail_str:
|
||||||
|
self.finished = True
|
||||||
|
self.finish_reason = FinishReason.STOP_STR
|
||||||
|
self.hit_stop_str = stop_str
|
||||||
|
return
|
||||||
|
|
||||||
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
|
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
|
||||||
old_output_str = self.tokenizer.decode(self.output_ids)
|
old_output_str = self.tokenizer.decode(self.output_ids)
|
||||||
# FIXME: This logic does not really solve the problem of determining whether
|
# FIXME: This logic does not really solve the problem of determining whether
|
||||||
@@ -132,35 +164,6 @@ class Req:
|
|||||||
# print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
|
# print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
|
||||||
# print("*" * 100)
|
# print("*" * 100)
|
||||||
|
|
||||||
def check_finished(self):
|
|
||||||
if self.finished:
|
|
||||||
return
|
|
||||||
|
|
||||||
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
|
|
||||||
self.finished = True
|
|
||||||
self.finish_reason = FinishReason.LENGTH
|
|
||||||
return
|
|
||||||
|
|
||||||
if (
|
|
||||||
self.output_ids[-1] == self.tokenizer.eos_token_id
|
|
||||||
and self.sampling_params.ignore_eos == False
|
|
||||||
):
|
|
||||||
self.finished = True
|
|
||||||
self.finish_reason = FinishReason.EOS_TOKEN
|
|
||||||
return
|
|
||||||
|
|
||||||
if len(self.sampling_params.stop_strs) > 0:
|
|
||||||
tail_str = self.tokenizer.decode(
|
|
||||||
self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
|
|
||||||
)
|
|
||||||
|
|
||||||
for stop_str in self.sampling_params.stop_strs:
|
|
||||||
if stop_str in tail_str:
|
|
||||||
self.finished = True
|
|
||||||
self.finish_reason = FinishReason.STOP_STR
|
|
||||||
self.hit_stop_str = stop_str
|
|
||||||
return
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
|
return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
|
||||||
|
|
||||||
|
|||||||
@@ -679,6 +679,7 @@ class ModelRpcServer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def abort_request(self, recv_req):
|
def abort_request(self, recv_req):
|
||||||
|
# Delete requests in the waiting queue
|
||||||
to_del = None
|
to_del = None
|
||||||
for i, req in enumerate(self.forward_queue):
|
for i, req in enumerate(self.forward_queue):
|
||||||
if req.rid == recv_req.rid:
|
if req.rid == recv_req.rid:
|
||||||
@@ -688,6 +689,14 @@ class ModelRpcServer:
|
|||||||
if to_del is not None:
|
if to_del is not None:
|
||||||
del self.forward_queue[to_del]
|
del self.forward_queue[to_del]
|
||||||
|
|
||||||
|
# Delete requests in the running batch
|
||||||
|
if self.running_batch:
|
||||||
|
for req in self.running_batch.reqs:
|
||||||
|
if req.rid == recv_req.rid:
|
||||||
|
req.finished = True
|
||||||
|
req.finish_reason = FinishReason.ABORT
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
class ModelRpcService(rpyc.Service):
|
class ModelRpcService(rpyc.Service):
|
||||||
exposed_ModelRpcServer = ModelRpcServer
|
exposed_ModelRpcServer = ModelRpcServer
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import transformers
|
|||||||
import uvloop
|
import uvloop
|
||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
|
from fastapi import BackgroundTasks
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import (
|
from sglang.srt.hf_transformers_utils import (
|
||||||
get_config,
|
get_config,
|
||||||
@@ -165,7 +166,7 @@ class TokenizerManager:
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(event.wait(), timeout=5)
|
await asyncio.wait_for(event.wait(), timeout=4)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
if request is not None and await request.is_disconnected():
|
if request is not None and await request.is_disconnected():
|
||||||
self.abort_request(rid)
|
self.abort_request(rid)
|
||||||
@@ -243,7 +244,7 @@ class TokenizerManager:
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(state.event.wait(), timeout=5)
|
await asyncio.wait_for(state.event.wait(), timeout=4)
|
||||||
break
|
break
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
if request is not None and await request.is_disconnected():
|
if request is not None and await request.is_disconnected():
|
||||||
@@ -270,10 +271,26 @@ class TokenizerManager:
|
|||||||
self.send_to_router.send_pyobj(req)
|
self.send_to_router.send_pyobj(req)
|
||||||
|
|
||||||
def abort_request(self, rid):
|
def abort_request(self, rid):
|
||||||
|
if rid not in self.rid_to_state:
|
||||||
|
return
|
||||||
del self.rid_to_state[rid]
|
del self.rid_to_state[rid]
|
||||||
req = AbortReq(rid)
|
req = AbortReq(rid)
|
||||||
self.send_to_router.send_pyobj(req)
|
self.send_to_router.send_pyobj(req)
|
||||||
|
|
||||||
|
def create_abort_task(self, obj):
|
||||||
|
# Abort the request if the client is disconnected.
|
||||||
|
async def abort_request():
|
||||||
|
await asyncio.sleep(3)
|
||||||
|
if obj.is_single:
|
||||||
|
self.abort_request(obj.rid)
|
||||||
|
else:
|
||||||
|
for rid in obj.rids:
|
||||||
|
self.abort_request(rid)
|
||||||
|
|
||||||
|
background_tasks = BackgroundTasks()
|
||||||
|
background_tasks.add_task(abort_request)
|
||||||
|
return background_tasks
|
||||||
|
|
||||||
def create_handle_loop(self):
|
def create_handle_loop(self):
|
||||||
self.to_create_loop = False
|
self.to_create_loop = False
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
"""Conversion between OpenAI APIs and native SRT APIs"""
|
"""Conversion between OpenAI APIs and native SRT APIs"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
|
|
||||||
from sglang.srt.conversation import (
|
from sglang.srt.conversation import (
|
||||||
Conversation,
|
Conversation,
|
||||||
@@ -27,14 +29,36 @@ from sglang.srt.openai_protocol import (
|
|||||||
CompletionResponseStreamChoice,
|
CompletionResponseStreamChoice,
|
||||||
CompletionStreamResponse,
|
CompletionStreamResponse,
|
||||||
DeltaMessage,
|
DeltaMessage,
|
||||||
|
ErrorResponse,
|
||||||
LogProbs,
|
LogProbs,
|
||||||
UsageInfo,
|
UsageInfo,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import jsonify_pydantic_model
|
|
||||||
|
|
||||||
chat_template_name = None
|
chat_template_name = None
|
||||||
|
|
||||||
|
|
||||||
|
def create_error_response(
|
||||||
|
message: str,
|
||||||
|
err_type: str = "BadRequestError",
|
||||||
|
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST):
|
||||||
|
error = ErrorResponse(message=message,
|
||||||
|
type=err_type,
|
||||||
|
code=status_code.value)
|
||||||
|
return JSONResponse(content=error.model_dump(),
|
||||||
|
status_code=error.code)
|
||||||
|
|
||||||
|
|
||||||
|
def create_streaming_error_response(
|
||||||
|
message: str,
|
||||||
|
err_type: str = "BadRequestError",
|
||||||
|
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
|
||||||
|
error = ErrorResponse(message=message,
|
||||||
|
type=err_type,
|
||||||
|
code=status_code.value)
|
||||||
|
json_str = json.dumps({"error": error.model_dump()})
|
||||||
|
return json_str
|
||||||
|
|
||||||
|
|
||||||
def load_chat_template_for_openai_api(chat_template_arg):
|
def load_chat_template_for_openai_api(chat_template_arg):
|
||||||
global chat_template_name
|
global chat_template_name
|
||||||
|
|
||||||
@@ -74,8 +98,8 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
|||||||
request_json = await raw_request.json()
|
request_json = await raw_request.json()
|
||||||
request = CompletionRequest(**request_json)
|
request = CompletionRequest(**request_json)
|
||||||
|
|
||||||
# TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
|
if request.n != 1:
|
||||||
assert request.n == 1
|
return create_error_response("n != 1 is not supported")
|
||||||
|
|
||||||
adapted_request = GenerateReqInput(
|
adapted_request = GenerateReqInput(
|
||||||
text=request.prompt,
|
text=request.prompt,
|
||||||
@@ -93,79 +117,88 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
|||||||
return_text_in_logprobs=True,
|
return_text_in_logprobs=True,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
)
|
)
|
||||||
adapted_request.post_init()
|
|
||||||
|
|
||||||
if adapted_request.stream:
|
if adapted_request.stream:
|
||||||
|
|
||||||
async def generate_stream_resp():
|
async def generate_stream_resp():
|
||||||
stream_buffer = ""
|
stream_buffer = ""
|
||||||
n_prev_token = 0
|
n_prev_token = 0
|
||||||
async for content in tokenizer_manager.generate_request(adapted_request):
|
try:
|
||||||
text = content["text"]
|
async for content in tokenizer_manager.generate_request(
|
||||||
prompt_tokens = content["meta_info"]["prompt_tokens"]
|
adapted_request, raw_request):
|
||||||
completion_tokens = content["meta_info"]["completion_tokens"]
|
text = content["text"]
|
||||||
|
prompt_tokens = content["meta_info"]["prompt_tokens"]
|
||||||
|
completion_tokens = content["meta_info"]["completion_tokens"]
|
||||||
|
|
||||||
if not stream_buffer: # The first chunk
|
if not stream_buffer: # The first chunk
|
||||||
if request.echo:
|
if request.echo:
|
||||||
# Prepend prompt in response text.
|
# Prepend prompt in response text.
|
||||||
text = request.prompt + text
|
text = request.prompt + text
|
||||||
|
|
||||||
if request.logprobs:
|
if request.logprobs:
|
||||||
# The first chunk and echo is enabled.
|
# The first chunk and echo is enabled.
|
||||||
if not stream_buffer and request.echo:
|
if not stream_buffer and request.echo:
|
||||||
prefill_token_logprobs = content["meta_info"][
|
prefill_token_logprobs = content["meta_info"][
|
||||||
"prefill_token_logprobs"
|
"prefill_token_logprobs"
|
||||||
]
|
]
|
||||||
prefill_top_logprobs = content["meta_info"][
|
prefill_top_logprobs = content["meta_info"][
|
||||||
"prefill_top_logprobs"
|
"prefill_top_logprobs"
|
||||||
]
|
]
|
||||||
|
else:
|
||||||
|
prefill_token_logprobs = None
|
||||||
|
prefill_top_logprobs = None
|
||||||
|
|
||||||
|
logprobs = to_openai_style_logprobs(
|
||||||
|
prefill_token_logprobs=prefill_token_logprobs,
|
||||||
|
prefill_top_logprobs=prefill_top_logprobs,
|
||||||
|
decode_token_logprobs=content["meta_info"][
|
||||||
|
"decode_token_logprobs"
|
||||||
|
][n_prev_token:],
|
||||||
|
decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
|
||||||
|
n_prev_token:
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
|
||||||
else:
|
else:
|
||||||
prefill_token_logprobs = None
|
logprobs = None
|
||||||
prefill_top_logprobs = None
|
|
||||||
|
|
||||||
logprobs = to_openai_style_logprobs(
|
delta = text[len(stream_buffer) :]
|
||||||
prefill_token_logprobs=prefill_token_logprobs,
|
stream_buffer = content["text"]
|
||||||
prefill_top_logprobs=prefill_top_logprobs,
|
choice_data = CompletionResponseStreamChoice(
|
||||||
decode_token_logprobs=content["meta_info"][
|
index=0,
|
||||||
"decode_token_logprobs"
|
text=delta,
|
||||||
][n_prev_token:],
|
logprobs=logprobs,
|
||||||
decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
|
finish_reason=content["meta_info"]["finish_reason"],
|
||||||
n_prev_token:
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
chunk = CompletionStreamResponse(
|
||||||
n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
|
id=content["meta_info"]["id"],
|
||||||
else:
|
object="text_completion",
|
||||||
logprobs = None
|
choices=[choice_data],
|
||||||
|
model=request.model,
|
||||||
delta = text[len(stream_buffer) :]
|
usage=UsageInfo(
|
||||||
stream_buffer = content["text"]
|
prompt_tokens=prompt_tokens,
|
||||||
choice_data = CompletionResponseStreamChoice(
|
completion_tokens=completion_tokens,
|
||||||
index=0,
|
total_tokens=prompt_tokens + completion_tokens,
|
||||||
text=delta,
|
),
|
||||||
logprobs=logprobs,
|
)
|
||||||
finish_reason=content["meta_info"]["finish_reason"],
|
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||||
)
|
except ValueError as e:
|
||||||
chunk = CompletionStreamResponse(
|
error = create_streaming_error_response(str(e))
|
||||||
id=content["meta_info"]["id"],
|
yield f"data: {error}\n\n"
|
||||||
object="text_completion",
|
|
||||||
choices=[choice_data],
|
|
||||||
model=request.model,
|
|
||||||
usage=UsageInfo(
|
|
||||||
prompt_tokens=prompt_tokens,
|
|
||||||
completion_tokens=completion_tokens,
|
|
||||||
total_tokens=prompt_tokens + completion_tokens,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
|
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
|
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
|
||||||
|
background=tokenizer_manager.create_abort_task(adapted_request))
|
||||||
|
|
||||||
# Non-streaming response.
|
# Non-streaming response.
|
||||||
ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
|
try:
|
||||||
ret = ret[0] if isinstance(ret, list) else ret
|
ret = await tokenizer_manager.generate_request(
|
||||||
|
adapted_request, raw_request).__anext__()
|
||||||
|
except ValueError as e:
|
||||||
|
return create_error_response(str(e))
|
||||||
|
|
||||||
|
ret = ret[0] if isinstance(ret, list) else ret
|
||||||
prompt_tokens = ret["meta_info"]["prompt_tokens"]
|
prompt_tokens = ret["meta_info"]["prompt_tokens"]
|
||||||
completion_tokens = ret["meta_info"]["completion_tokens"]
|
completion_tokens = ret["meta_info"]["completion_tokens"]
|
||||||
text = ret["text"]
|
text = ret["text"]
|
||||||
@@ -212,8 +245,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|||||||
request_json = await raw_request.json()
|
request_json = await raw_request.json()
|
||||||
request = ChatCompletionRequest(**request_json)
|
request = ChatCompletionRequest(**request_json)
|
||||||
|
|
||||||
# TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
|
if request.n != 1:
|
||||||
assert request.n == 1
|
return create_error_response("n != 1 is not supported")
|
||||||
|
|
||||||
# Prep the data needed for the underlying GenerateReqInput:
|
# Prep the data needed for the underlying GenerateReqInput:
|
||||||
# - prompt: The full prompt string.
|
# - prompt: The full prompt string.
|
||||||
@@ -258,7 +291,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|||||||
},
|
},
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
)
|
)
|
||||||
adapted_request.post_init()
|
|
||||||
|
|
||||||
if adapted_request.stream:
|
if adapted_request.stream:
|
||||||
|
|
||||||
@@ -266,13 +298,29 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|||||||
is_first = True
|
is_first = True
|
||||||
|
|
||||||
stream_buffer = ""
|
stream_buffer = ""
|
||||||
async for content in tokenizer_manager.generate_request(adapted_request):
|
try:
|
||||||
if is_first:
|
async for content in tokenizer_manager.generate_request(adapted_request, raw_request):
|
||||||
# First chunk with role
|
if is_first:
|
||||||
is_first = False
|
# First chunk with role
|
||||||
|
is_first = False
|
||||||
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
|
index=0,
|
||||||
|
delta=DeltaMessage(role="assistant"),
|
||||||
|
finish_reason=content["meta_info"]["finish_reason"],
|
||||||
|
)
|
||||||
|
chunk = ChatCompletionStreamResponse(
|
||||||
|
id=content["meta_info"]["id"],
|
||||||
|
choices=[choice_data],
|
||||||
|
model=request.model,
|
||||||
|
)
|
||||||
|
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||||
|
|
||||||
|
text = content["text"]
|
||||||
|
delta = text[len(stream_buffer) :]
|
||||||
|
stream_buffer = text
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
index=0,
|
index=0,
|
||||||
delta=DeltaMessage(role="assistant"),
|
delta=DeltaMessage(content=delta),
|
||||||
finish_reason=content["meta_info"]["finish_reason"],
|
finish_reason=content["meta_info"]["finish_reason"],
|
||||||
)
|
)
|
||||||
chunk = ChatCompletionStreamResponse(
|
chunk = ChatCompletionStreamResponse(
|
||||||
@@ -280,28 +328,22 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|||||||
choices=[choice_data],
|
choices=[choice_data],
|
||||||
model=request.model,
|
model=request.model,
|
||||||
)
|
)
|
||||||
yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
|
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||||
|
except ValueError as e:
|
||||||
text = content["text"]
|
error = create_streaming_error_response(str(e))
|
||||||
delta = text[len(stream_buffer) :]
|
yield f"data: {error}\n\n"
|
||||||
stream_buffer = text
|
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
|
||||||
index=0,
|
|
||||||
delta=DeltaMessage(content=delta),
|
|
||||||
finish_reason=content["meta_info"]["finish_reason"],
|
|
||||||
)
|
|
||||||
chunk = ChatCompletionStreamResponse(
|
|
||||||
id=content["meta_info"]["id"],
|
|
||||||
choices=[choice_data],
|
|
||||||
model=request.model,
|
|
||||||
)
|
|
||||||
yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
|
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
|
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
|
||||||
|
background=tokenizer_manager.create_abort_task(adapted_request))
|
||||||
|
|
||||||
# Non-streaming response.
|
# Non-streaming response.
|
||||||
ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
|
try:
|
||||||
|
ret = await tokenizer_manager.generate_request(
|
||||||
|
adapted_request, raw_request).__anext__()
|
||||||
|
except ValueError as e:
|
||||||
|
return create_error_response(str(e))
|
||||||
|
|
||||||
prompt_tokens = ret["meta_info"]["prompt_tokens"]
|
prompt_tokens = ret["meta_info"]["prompt_tokens"]
|
||||||
completion_tokens = ret["meta_info"]["completion_tokens"]
|
completion_tokens = ret["meta_info"]["completion_tokens"]
|
||||||
choice_data = ChatCompletionResponseChoice(
|
choice_data = ChatCompletionResponseChoice(
|
||||||
|
|||||||
@@ -7,6 +7,14 @@ from pydantic import BaseModel, Field
|
|||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorResponse(BaseModel):
|
||||||
|
object: str = "error"
|
||||||
|
message: str
|
||||||
|
type: str
|
||||||
|
param: Optional[str] = None
|
||||||
|
code: int
|
||||||
|
|
||||||
|
|
||||||
class LogProbs(BaseModel):
|
class LogProbs(BaseModel):
|
||||||
text_offset: List[int] = Field(default_factory=list)
|
text_offset: List[int] = Field(default_factory=list)
|
||||||
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
||||||
|
|||||||
@@ -93,7 +93,8 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
|||||||
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(stream_results(), media_type="text/event-stream")
|
return StreamingResponse(stream_results(), media_type="text/event-stream",
|
||||||
|
background=tokenizer_manager.create_abort_task(obj))
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
||||||
|
|||||||
@@ -392,14 +392,4 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
|
|||||||
content={"detail": "Invalid API Key"},
|
content={"detail": "Invalid API Key"},
|
||||||
)
|
)
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
# FIXME: Remove this once we drop support for pydantic 1.x
|
|
||||||
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def jsonify_pydantic_model(obj: BaseModel):
|
|
||||||
if IS_PYDANTIC_1:
|
|
||||||
return obj.json(ensure_ascii=False)
|
|
||||||
return obj.model_dump_json()
|
|
||||||
Reference in New Issue
Block a user