Support v1/chat/completions (#50)

This commit is contained in:
Cody Yu
2024-01-18 23:43:09 -08:00
committed by GitHub
parent 61d4c93962
commit 23471f9aa3
6 changed files with 705 additions and 9 deletions


@@ -2,6 +2,7 @@
import asyncio
import json
import multiprocessing as mp
import os
import sys
import threading
import time
@@ -17,15 +18,29 @@ import uvloop
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.conversation import (
Conversation,
SeparatorStyle,
chat_template_exists,
generate_chat_conv,
register_conv_template,
)
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.openai_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
DeltaMessage,
UsageInfo,
)
from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -37,6 +52,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
app = FastAPI()
tokenizer_manager = None
chat_template_name = None
@app.get("/get_model_info")
@@ -46,6 +62,7 @@ async def get_model_info():
}
return result
async def stream_generator(obj):
async for out in tokenizer_manager.generate_request(obj):
yield out
@@ -61,7 +78,7 @@ async def generate_request(obj: GenerateReqInput):
async for out in stream_generator(obj):
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(stream_results(), media_type="text/event-stream")
ret = await tokenizer_manager.generate_request(obj).__anext__()
@@ -91,11 +108,15 @@ async def v1_completions(raw_request: Request):
adapted_request.post_init()
if adapted_request.stream:
async def generate_stream_resp():
stream_buffer = ""
async for content in stream_generator(adapted_request):
text = content["text"]
prompt_tokens = content["meta_info"]["prompt_tokens"]
completion_tokens = content["meta_info"]["completion_tokens"]
delta = text[len(stream_buffer) :]
stream_buffer = text
choice_data = CompletionResponseStreamChoice(
index=0,
@@ -108,12 +129,17 @@ async def v1_completions(raw_request: Request):
object="text_completion",
choices=[choice_data],
model=request.model,
usage=UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
# Non-streaming response.
ret = await generate_request(adapted_request)
@@ -121,7 +147,7 @@ async def v1_completions(raw_request: Request):
index=0,
text=ret["text"],
logprobs=None,
finish_reason=None,  # TODO(comaniac): Add finish reason.
)
prompt_tokens = ret["meta_info"]["prompt_tokens"]
@@ -139,8 +165,108 @@ async def v1_completions(raw_request: Request):
return response
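
When stream=True, the endpoint above emits Server-Sent Events: one "data: {...}" line per chunk, terminated by "data: [DONE]". A minimal client-side sketch for consuming the /v1/completions stream; the host, port, and model name are assumptions for illustration, not values fixed by this commit:

import json
import requests

# Assumed local endpoint; adjust host/port to wherever the server is launched.
url = "http://localhost:30000/v1/completions"
payload = {
    "model": "placeholder-model",  # illustrative model name
    "prompt": "The capital of France is",
    "max_tokens": 16,
    "stream": True,
}

with requests.post(url, json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        # Each SSE event is prefixed with "data: "; blank keep-alive lines are skipped.
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        print(chunk["choices"][0]["text"], end="", flush=True)
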
@app.post("/v1/chat/completions")
async def v1_chat_completions(raw_request: Request):
request_json = await raw_request.json()
request = ChatCompletionRequest(**request_json)
# TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
assert request.n == 1
if not isinstance(request.messages, str):
# Apply chat template and its stop strings.
if chat_template_name is None:
prompt = tokenizer_manager.tokenizer.apply_chat_template(
request.messages, tokenize=False, add_generation_prompt=True
)
stop = request.stop
else:
conv = generate_chat_conv(request, chat_template_name)
prompt = conv.get_prompt()
stop = conv.stop_str or []
if request.stop:
if isinstance(request.stop, str):
stop.append(request.stop)
else:
stop.extend(request.stop)
else:
# Use the raw prompt and stop strings if the messages field is already a string.
prompt = request.messages
stop = request.stop
adapted_request = GenerateReqInput(
text=prompt,
sampling_params={
"temperature": request.temperature,
"max_new_tokens": request.max_tokens,
"stop": stop,
"top_p": request.top_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
},
stream=request.stream,
)
adapted_request.post_init()
if adapted_request.stream:
async def generate_stream_resp():
is_first = True
stream_buffer = ""
async for content in stream_generator(adapted_request):
if is_first:
# First chunk with role
is_first = False
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"], choices=[choice_data], model=request.model
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
text = content["text"]
delta = text[len(stream_buffer) :]
stream_buffer = text
choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=DeltaMessage(content=delta), finish_reason=None
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"], choices=[choice_data], model=request.model
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
# Non-streaming response.
ret = await generate_request(adapted_request)
prompt_tokens = ret["meta_info"]["prompt_tokens"]
completion_tokens = ret["meta_info"]["completion_tokens"]
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=ret["text"]),
finish_reason=None, # TODO(comaniac): Add finish reason.
)
response = ChatCompletionResponse(
id=ret["meta_info"]["id"],
model=request.model,
choices=[choice_data],
usage=UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
return response
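
Once the server is running, the new endpoint is intended to work with OpenAI-style chat clients. A usage sketch with the openai Python package (v1+ client interface); the base URL, port, and model name are assumptions rather than values defined by this commit, and the API key is not checked by the server:

from openai import OpenAI

# Assumed local endpoint and illustrative model name.
client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="placeholder-model",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ],
    temperature=0,
    max_tokens=32,
)
print(response.choices[0].message.content)

Passing stream=True in the request instead exercises the chunked DeltaMessage path produced by generate_stream_resp above.
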
def launch_server(server_args, pipe_finish_writer):
global tokenizer_manager
global chat_template_name
# Allocate ports
can_use_ports = alloc_usable_network_port(
@@ -154,6 +280,36 @@ def launch_server(server_args, pipe_finish_writer):
model_rpc_ports=can_use_ports[4:],
)
# Load chat template if needed
if server_args.chat_template is not None:
if not chat_template_exists(server_args.chat_template):
if not os.path.exists(server_args.chat_template):
raise RuntimeError(
f"Chat template {server_args.chat_template} is not a built-in template name "
"or a valid chat template file path."
)
with open(server_args.chat_template, "r") as filep:
template = json.load(filep)
try:
sep_style = SeparatorStyle[template["sep_style"]]
except KeyError:
raise ValueError(f"Unknown separator style: {template['sep_style']}") from None
register_conv_template(
Conversation(
name=template["name"],
system_template=template["system"] + "\n{system_message}",
system_message=template.get("system_message", ""),
roles=(template["user"], template["assistant"]),
sep_style=sep_style,
sep=template.get("sep", "\n"),
stop_str=template["stop_str"],
),
override=True,
)
chat_template_name = template["name"]
else:
chat_template_name = server_args.chat_template
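
For reference, the file-based branch above expects a JSON template whose keys mirror the Conversation fields it reads: name, system, system_message, user, assistant, sep_style, sep, and stop_str, with sep_style naming a SeparatorStyle member. A sketch of writing such a file; the member name ADD_COLON_SINGLE, the role prefixes, and the launch command are assumptions, not values taken from this diff:

import json

# Illustrative template; every value is an example, and "ADD_COLON_SINGLE" is assumed
# to be a valid SeparatorStyle member in sglang.srt.conversation.
custom_template = {
    "name": "my-custom-chat",
    "system": "SYSTEM:",  # prefix only; the loader appends "\n{system_message}" itself
    "system_message": "You are a helpful assistant.",
    "user": "USER:",
    "assistant": "ASSISTANT:",
    "sep_style": "ADD_COLON_SINGLE",
    "sep": "\n",
    "stop_str": ["USER:"],  # a list keeps the stop-string merging in v1_chat_completions well-defined
}

with open("my_chat_template.json", "w") as f:
    json.dump(custom_template, f, indent=2)

# The server would then be started with something like:
#   python -m sglang.launch_server --model-path <model> --chat-template my_chat_template.json
# (module path and flag spelling assumed from server_args.chat_template).
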
# Launch processes
tokenizer_manager = TokenizerManager(server_args, port_args)
pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)