Support v1/chat/completions (#50)
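
The commit adds an OpenAI-compatible /v1/chat/completions endpoint next to the existing /v1/completions route. As a rough sketch of how a client could exercise the new endpoint once a server is up (the host, port, and model name below are placeholder assumptions, not values taken from this commit):

# Hypothetical non-streaming request; URL and model name are placeholders.
# The response follows the ChatCompletionResponse schema built in the diff:
# choices[0].message.content plus a usage block with token counts.
import requests

resp = requests.post(
    "http://127.0.0.1:30000/v1/chat/completions",
    json={
        "model": "placeholder-model",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What does this server do?"},
        ],
        "temperature": 0.7,
        "max_tokens": 64,
    },
)
data = resp.json()
print(data["choices"][0]["message"]["content"])
print(data["usage"])
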
@@ -2,6 +2,7 @@
import asyncio
import json
import multiprocessing as mp
import os
import sys
import threading
import time
@@ -17,15 +18,29 @@ import uvloop
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.conversation import (
    Conversation,
    SeparatorStyle,
    chat_template_exists,
    generate_chat_conv,
    register_conv_template,
)
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.openai_protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    DeltaMessage,
    UsageInfo,
)
from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -37,6 +52,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

app = FastAPI()
tokenizer_manager = None
chat_template_name = None


@app.get("/get_model_info")
@@ -46,6 +62,7 @@ async def get_model_info():
    }
    return result


async def stream_generator(obj):
    async for out in tokenizer_manager.generate_request(obj):
        yield out
@@ -61,7 +78,7 @@ async def generate_request(obj: GenerateReqInput):
            async for out in stream_generator(obj):
                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"


        return StreamingResponse(stream_results(), media_type="text/event-stream")

    ret = await tokenizer_manager.generate_request(obj).__anext__()
@@ -91,11 +108,15 @@ async def v1_completions(raw_request: Request):
    adapted_request.post_init()

    if adapted_request.stream:

        async def generate_stream_resp():
            stream_buffer = ""
            async for content in stream_generator(adapted_request):
                text = content["text"]
                prompt_tokens = content["meta_info"]["prompt_tokens"]
                completion_tokens = content["meta_info"]["completion_tokens"]

                delta = text[len(stream_buffer) :]
                stream_buffer = text
                choice_data = CompletionResponseStreamChoice(
                    index=0,
@@ -108,12 +129,17 @@ async def v1_completions(raw_request: Request):
                    object="text_completion",
                    choices=[choice_data],
                    model=request.model,
                    usage=UsageInfo(
                        prompt_tokens=prompt_tokens,
                        completion_tokens=completion_tokens,
                        total_tokens=prompt_tokens + completion_tokens,
                    ),
                )
                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")


    # Non-streaming response.
    ret = await generate_request(adapted_request)

@@ -121,7 +147,7 @@ async def v1_completions(raw_request: Request):
        index=0,
        text=ret["text"],
        logprobs=None,
        finish_reason=None,  # TODO(comaniac): Add finish reason.
    )

    prompt_tokens = ret["meta_info"]["prompt_tokens"]
@@ -139,8 +165,108 @@ async def v1_completions(raw_request: Request):
    return response


@app.post("/v1/chat/completions")
async def v1_chat_completions(raw_request: Request):
    request_json = await raw_request.json()
    request = ChatCompletionRequest(**request_json)

    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
    assert request.n == 1

    if not isinstance(request.messages, str):
        # Apply chat template and its stop strings.
        if chat_template_name is None:
            prompt = tokenizer_manager.tokenizer.apply_chat_template(
                request.messages, tokenize=False, add_generation_prompt=True
            )
            stop = request.stop
        else:
            conv = generate_chat_conv(request, chat_template_name)
            prompt = conv.get_prompt()
            stop = conv.stop_str or []
            if request.stop:
                if isinstance(request.stop, str):
                    stop.append(request.stop)
                else:
                    stop.extend(request.stop)
    else:
        # Use the raw prompt and stop strings if the messages field is already a string.
        prompt = request.messages
        stop = request.stop

    adapted_request = GenerateReqInput(
        text=prompt,
        sampling_params={
            "temperature": request.temperature,
            "max_new_tokens": request.max_tokens,
            "stop": stop,
            "top_p": request.top_p,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
        },
        stream=request.stream,
    )
    adapted_request.post_init()

    if adapted_request.stream:

        async def generate_stream_resp():
            is_first = True

            stream_buffer = ""
            async for content in stream_generator(adapted_request):
                if is_first:
                    # First chunk with role
                    is_first = False
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=0,
                        delta=DeltaMessage(role="assistant"),
                        finish_reason=None,
                    )
                    chunk = ChatCompletionStreamResponse(
                        id=content["meta_info"]["id"], choices=[choice_data], model=request.model
                    )
                    yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"

                text = content["text"]
                delta = text[len(stream_buffer) :]
                stream_buffer = text
                choice_data = ChatCompletionResponseStreamChoice(
                    index=0, delta=DeltaMessage(content=delta), finish_reason=None
                )
                chunk = ChatCompletionStreamResponse(
                    id=content["meta_info"]["id"], choices=[choice_data], model=request.model
                )
                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")

    # Non-streaming response.
    ret = await generate_request(adapted_request)
    prompt_tokens = ret["meta_info"]["prompt_tokens"]
    completion_tokens = ret["meta_info"]["completion_tokens"]
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=ret["text"]),
        finish_reason=None,  # TODO(comaniac): Add finish reason.
    )
    response = ChatCompletionResponse(
        id=ret["meta_info"]["id"],
        model=request.model,
        choices=[choice_data],
        usage=UsageInfo(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
    return response
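
The streaming branch of v1_chat_completions above emits OpenAI-style server-sent events, one "data: {...}" line per chunk, terminated by "data: [DONE]". A minimal sketch of a client that consumes that stream, again assuming a placeholder host, port, and model name:

# Hypothetical SSE consumer; the first chunk carries only the assistant role,
# later chunks carry incremental text in choices[0].delta.
import json
import requests

with requests.post(
    "http://127.0.0.1:30000/v1/chat/completions",
    json={
        "model": "placeholder-model",
        "messages": [{"role": "user", "content": "Tell me a joke."}],
        "stream": True,
    },
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        payload = line.decode()[len("data: "):]
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)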


def launch_server(server_args, pipe_finish_writer):
    global tokenizer_manager
    global chat_template_name

    # Allocate ports
    can_use_ports = alloc_usable_network_port(
@@ -154,6 +280,36 @@ def launch_server(server_args, pipe_finish_writer):
        model_rpc_ports=can_use_ports[4:],
    )

    # Load chat template if needed
    if server_args.chat_template is not None:
        if not chat_template_exists(server_args.chat_template):
            if not os.path.exists(server_args.chat_template):
                raise RuntimeError(
                    f"Chat template {server_args.chat_template} is not a built-in template name "
                    "or a valid chat template file path."
                )
            with open(server_args.chat_template, "r") as filep:
                template = json.load(filep)
                try:
                    sep_style = SeparatorStyle[template["sep_style"]]
                except KeyError:
                    raise ValueError(f"Unknown separator style: {template['sep_style']}") from None
                register_conv_template(
                    Conversation(
                        name=template["name"],
                        system_template=template["system"] + "\n{system_message}",
                        system_message=template.get("system_message", ""),
                        roles=(template["user"], template["assistant"]),
                        sep_style=sep_style,
                        sep=template.get("sep", "\n"),
                        stop_str=template["stop_str"],
                    ),
                    override=True,
                )
            chat_template_name = template["name"]
        else:
            chat_template_name = server_args.chat_template

    # Launch processes
    tokenizer_manager = TokenizerManager(server_args, port_args)
    pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
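
The template-loading branch in launch_server reads a JSON description of a conversation format whenever the requested chat template is not a built-in name. A minimal sketch of such a file, written from Python; the template name, role tags, separator style, and stop string are illustrative assumptions, and the sep_style value must match a member of SeparatorStyle in sglang.srt.conversation:

# Sketch of a custom chat-template file for the loader above.  "name",
# "system", "user", "assistant", "sep_style", and "stop_str" are the keys the
# loader reads; "system_message" and "sep" are optional.  Values are
# placeholders for illustration only.
import json

template = {
    "name": "my_custom_template",
    "system": "SYSTEM:",                  # used as system_template + "\n{system_message}"
    "system_message": "You are a helpful assistant.",
    "user": "USER:",
    "assistant": "ASSISTANT:",
    "sep_style": "ADD_COLON_SINGLE",      # assumed to be a valid SeparatorStyle name
    "sep": "\n",
    "stop_str": "</s>",
}

with open("my_custom_template.json", "w") as fout:
    json.dump(template, fout, indent=2)

Passing this file path as the server's chat_template argument then registers the conversation under its "name" and selects it for /v1/chat/completions prompts.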