Support stream=True in v1/completions (#49)

Cody Yu
2024-01-18 17:00:56 -08:00
committed by GitHub
parent 98a3e8ef78
commit 61d4c93962
7 changed files with 233 additions and 39 deletions

pyproject.toml

@@ -19,7 +19,7 @@ dependencies = [
 [project.optional-dependencies]
 srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
-       "interegular", "lark", "numba"]
+       "interegular", "lark", "numba", "pydantic"]
 openai = ["openai>=1.0"]
 anthropic = ["anthropic"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]

python/sglang/backend/runtime_endpoint.py

@@ -116,9 +116,12 @@ class RuntimeEndpoint(BaseBackend):
         pos = 0
         incomplete_text = ""
-        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-            if chunk:
-                data = json.loads(chunk.decode())
+        for chunk in response.iter_lines(decode_unicode=False):
+            chunk = chunk.decode("utf-8")
+            if chunk and chunk.startswith("data:"):
+                if chunk == "data: [DONE]":
+                    break
+                data = json.loads(chunk[5:].strip("\n"))
                 text = find_printable_text(data["text"][pos:])
                 meta_info = data["meta_info"]
                 pos += len(text)
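For reference, a minimal client-side sketch of the framing this hunk now parses: /generate emits server-sent events ("data: {...}" lines terminated by "data: [DONE]") instead of NUL-delimited JSON. The server address and request payload below are illustrative assumptions, not part of this commit.

import json

import requests

# Assumed local address; point this at wherever the SRT server is launched.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0.7},
        "stream": True,
    },
    stream=True,
)

pos = 0
for chunk in response.iter_lines(decode_unicode=False):
    chunk = chunk.decode("utf-8")
    if chunk and chunk.startswith("data:"):
        if chunk == "data: [DONE]":
            break
        data = json.loads(chunk[5:].strip("\n"))
        # data["text"] holds the full generation so far; print only the new suffix.
        print(data["text"][pos:], end="", flush=True)
        pos = len(data["text"])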

python/sglang/srt/managers/openai_protocol.py

@@ -1,12 +1,67 @@
-from dataclasses import dataclass
-from typing import Any, List, Optional, Union
+import time
+from typing import Dict, List, Optional, Union
+
+from pydantic import BaseModel, Field
 
 
-@dataclass
-class CompletionRequest:
-    prompt: Union[str, List[Any]]
-    model: str = "default"
-    temperature: Optional[float] = 0.7
+class LogProbs(BaseModel):
+    text_offset: List[int] = Field(default_factory=list)
+    token_logprobs: List[Optional[float]] = Field(default_factory=list)
+    tokens: List[str] = Field(default_factory=list)
+    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
+
+
+class UsageInfo(BaseModel):
+    prompt_tokens: int = 0
+    total_tokens: int = 0
+    completion_tokens: Optional[int] = 0
+
+
+class CompletionRequest(BaseModel):
+    model: str
+    prompt: Union[str, List[str]]
+    suffix: Optional[str] = None
     max_tokens: Optional[int] = 16
-    stop: Optional[Union[str, List[str]]] = None
+    temperature: Optional[float] = 0.7
+    top_p: Optional[float] = 1.0
+    n: Optional[int] = 1
     stream: Optional[bool] = False
+    logprobs: Optional[int] = None
+    echo: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    presence_penalty: Optional[float] = 0.0
+    frequency_penalty: Optional[float] = 0.0
+    best_of: Optional[int] = None
+    logit_bias: Optional[Dict[str, float]] = None
+    user: Optional[str] = None
+
+
+class CompletionResponseChoice(BaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class CompletionResponse(BaseModel):
+    id: str
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionResponseChoice]
+    usage: UsageInfo
+
+
+class CompletionResponseStreamChoice(BaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class CompletionStreamResponse(BaseModel):
+    id: str
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionResponseStreamChoice]
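A small usage sketch (not part of the commit) of how these models frame one streaming chunk, assuming pydantic v1, which the .json() call in server.py below implies; the id and model values are made up.

from sglang.srt.managers.openai_protocol import (
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
)

# Hypothetical values; in the server the id comes from the request's meta_info.
choice = CompletionResponseStreamChoice(index=0, text="Hello", finish_reason=None)
chunk = CompletionStreamResponse(
    id="cmpl-123",
    model="default",
    choices=[choice],
)

# Pydantic v1 serialization, wrapped in the SSE "data: ...\n\n" framing the server uses.
print(f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n")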

python/sglang/srt/server.py

@@ -1,7 +1,5 @@
"""SRT: SGLang Runtime"""
import argparse
import asyncio
import dataclasses
import json
import multiprocessing as mp
import sys
@@ -16,12 +14,19 @@ import psutil
 import requests
 import uvicorn
 import uvloop
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.openai_protocol import CompletionRequest
+from sglang.srt.managers.openai_protocol import (
+    CompletionRequest,
+    CompletionResponse,
+    CompletionResponseChoice,
+    CompletionResponseStreamChoice,
+    CompletionStreamResponse,
+    UsageInfo,
+)
 from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -41,39 +46,97 @@ async def get_model_info():
     }
     return result
 
 
+async def stream_generator(obj):
+    async for out in tokenizer_manager.generate_request(obj):
+        yield out
+
+
 @app.post("/generate")
 async def generate_request(obj: GenerateReqInput):
     obj.post_init()
-    result_generator = tokenizer_manager.generate_request(obj)
     if obj.stream:
         async def stream_results():
-            async for out in result_generator:
-                yield (json.dumps(out) + "\0").encode("utf-8")
+            async for out in stream_generator(obj):
+                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            yield "data: [DONE]\n\n"
         return StreamingResponse(stream_results(), media_type="text/event-stream")
-    else:
-        ret = await result_generator.__anext__()
-        return ret
+
+    ret = await tokenizer_manager.generate_request(obj).__anext__()
+    return ret
 
 
 @app.post("/v1/completions")
-async def v1_completions(obj: CompletionRequest):
-    assert obj.n == 1
-    obj = GenerateReqInput(
-        text=obj.prompt,
+async def v1_completions(raw_request: Request):
+    request_json = await raw_request.json()
+    request = CompletionRequest(**request_json)
+
+    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
+    assert request.n == 1
+
+    adapted_request = GenerateReqInput(
+        text=request.prompt,
         sampling_params={
-            "temperature": obj.temperature,
-            "max_new_tokens": obj.max_tokens,
-            "stop": obj.stop,
+            "temperature": request.temperature,
+            "max_new_tokens": request.max_tokens,
+            "stop": request.stop,
+            "top_p": request.top_p,
+            "presence_penalty": request.presence_penalty,
+            "frequency_penalty": request.frequency_penalty,
         },
+        stream=request.stream,
     )
-    ret = await generate_request(obj)
-    return {
-        "choices": [{"text": ret["text"]}],
-    }
+    adapted_request.post_init()
+
+    if adapted_request.stream:
+
+        async def generate_stream_resp():
+            stream_buffer = ""
+            async for content in stream_generator(adapted_request):
+                text = content["text"]
+                delta = text[len(stream_buffer):]
+                stream_buffer = text
+                choice_data = CompletionResponseStreamChoice(
+                    index=0,
+                    text=delta,
+                    logprobs=None,
+                    finish_reason=None,
+                )
+                chunk = CompletionStreamResponse(
+                    id=content["meta_info"]["id"],
+                    object="text_completion",
+                    choices=[choice_data],
+                    model=request.model,
+                )
+                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+
+        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
+
+    # Non-streaming response.
+    ret = await generate_request(adapted_request)
+    choice_data = CompletionResponseChoice(
+        index=0,
+        text=ret["text"],
+        logprobs=None,
+        finish_reason=None,  # TODO(comaniac): Add finish reason.
+    )
+
+    prompt_tokens = ret["meta_info"]["prompt_tokens"]
+    completion_tokens = ret["meta_info"]["completion_tokens"]
+    response = CompletionResponse(
+        id=ret["meta_info"]["id"],
+        model=request.model,
+        choices=[choice_data],
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        ),
+    )
+    return response
 
 
 def launch_server(server_args, pipe_finish_writer):
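With streaming in place, the endpoint can be exercised with the OpenAI Python client (the "openai>=1.0" extra from pyproject.toml). A rough sketch, not a verified recipe: the base URL, port, model name, and placeholder API key are assumptions (the handler in this diff performs no auth).

from openai import OpenAI

# Assumed local server address; the key is a placeholder.
client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

stream = client.completions.create(
    model="default",
    prompt="Say this is a test.",
    max_tokens=32,
    temperature=0.7,
    stream=True,
)
for chunk in stream:
    # Each streamed chunk carries the text delta produced by generate_stream_resp above.
    print(chunk.choices[0].text, end="", flush=True)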