Support stream=True in v1/completions (#49)
pyproject.toml
@@ -19,7 +19,7 @@ dependencies = [
 
 [project.optional-dependencies]
 srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
-       "interegular", "lark", "numba"]
+       "interegular", "lark", "numba", "pydantic"]
 openai = ["openai>=1.0"]
 anthropic = ["anthropic"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
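The new pydantic dependency backs the OpenAI-protocol models introduced below. For readers unfamiliar with the pattern used there, a minimal sketch of pydantic v1 field defaults; the class and field names here are illustrative only, not part of the commit:

from typing import List, Optional

from pydantic import BaseModel, Field


class Example(BaseModel):
    # default_factory gives each instance a fresh list instead of a shared one
    tokens: List[str] = Field(default_factory=list)
    finish_reason: Optional[str] = None


print(Example().json())  # -> {"tokens": [], "finish_reason": null}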
sglang/backend/runtime_endpoint.py
@@ -116,9 +116,12 @@ class RuntimeEndpoint(BaseBackend):
         pos = 0
 
         incomplete_text = ""
-        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-            if chunk:
-                data = json.loads(chunk.decode())
+        for chunk in response.iter_lines(decode_unicode=False):
+            chunk = chunk.decode("utf-8")
+            if chunk and chunk.startswith("data:"):
+                if chunk == "data: [DONE]":
+                    break
+                data = json.loads(chunk[5:].strip("\n"))
                 text = find_printable_text(data["text"][pos:])
                 meta_info = data["meta_info"]
                 pos += len(text)
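The loop above now parses standard "data: ..." server-sent events instead of NUL-delimited JSON blobs. For context, a standalone sketch of the same consumption pattern against the /generate endpoint; the URL, port, and payload are illustrative assumptions, and the real backend additionally runs find_printable_text() to hold back partially detokenized UTF-8:

import json

import requests

response = requests.post(
    "http://localhost:30000/generate",  # assumed local server address
    json={
        "text": "Once upon a time,",
        "sampling_params": {"max_new_tokens": 32},
        "stream": True,
    },
    stream=True,
)

pos = 0
for chunk in response.iter_lines(decode_unicode=False):
    chunk = chunk.decode("utf-8")
    if chunk and chunk.startswith("data:"):
        if chunk == "data: [DONE]":
            break
        data = json.loads(chunk[5:].strip("\n"))
        # data["text"] is cumulative, so print only the new suffix
        print(data["text"][pos:], end="", flush=True)
        pos = len(data["text"])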
sglang/srt/managers/openai_protocol.py
@@ -1,12 +1,67 @@
-from dataclasses import dataclass
-from typing import Any, List, Optional, Union
+import time
+from typing import Dict, List, Optional, Union
+
+from pydantic import BaseModel, Field
 
 
-@dataclass
-class CompletionRequest:
-    prompt: Union[str, List[Any]]
-    model: str = "default"
-    temperature: Optional[float] = 0.7
-    max_tokens: Optional[int] = 16
-    stop: Optional[Union[str, List[str]]] = None
-    n: Optional[int] = 1
+class LogProbs(BaseModel):
+    text_offset: List[int] = Field(default_factory=list)
+    token_logprobs: List[Optional[float]] = Field(default_factory=list)
+    tokens: List[str] = Field(default_factory=list)
+    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
+
+
+class UsageInfo(BaseModel):
+    prompt_tokens: int = 0
+    total_tokens: int = 0
+    completion_tokens: Optional[int] = 0
+
+
+class CompletionRequest(BaseModel):
+    model: str
+    prompt: Union[str, List[str]]
+    suffix: Optional[str] = None
+    max_tokens: Optional[int] = 16
+    temperature: Optional[float] = 0.7
+    top_p: Optional[float] = 1.0
+    n: Optional[int] = 1
+    stream: Optional[bool] = False
+    logprobs: Optional[int] = None
+    echo: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    presence_penalty: Optional[float] = 0.0
+    frequency_penalty: Optional[float] = 0.0
+    best_of: Optional[int] = None
+    logit_bias: Optional[Dict[str, float]] = None
+    user: Optional[str] = None
+
+
+class CompletionResponseChoice(BaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class CompletionResponse(BaseModel):
+    id: str
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionResponseChoice]
+    usage: UsageInfo
+
+
+class CompletionResponseStreamChoice(BaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class CompletionStreamResponse(BaseModel):
+    id: str
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionResponseStreamChoice]
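A quick sanity check of the models above, using the pydantic v1 serialization API that the server relies on later (.json(exclude_unset=True, ensure_ascii=False)); the id, model, and text values are made up for illustration:

from sglang.srt.managers.openai_protocol import (
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
)

chunk = CompletionStreamResponse(
    id="cmpl-123",
    model="default",
    choices=[CompletionResponseStreamChoice(index=0, text="Hello")],
)
# Rendered exactly the way the streaming endpoint frames it:
print(f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}")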
sglang/srt/server.py
@@ -1,7 +1,5 @@
 """SRT: SGLang Runtime"""
 import argparse
 import asyncio
 import dataclasses
 import json
 import multiprocessing as mp
 import sys
@@ -16,12 +14,19 @@ import psutil
 import requests
 import uvicorn
 import uvloop
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.openai_protocol import CompletionRequest
+from sglang.srt.managers.openai_protocol import (
+    CompletionRequest,
+    CompletionResponse,
+    CompletionResponseChoice,
+    CompletionResponseStreamChoice,
+    CompletionStreamResponse,
+    UsageInfo
+)
 from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -41,39 +46,97 @@ async def get_model_info():
     }
     return result
 
 
+async def stream_generator(obj):
+    async for out in tokenizer_manager.generate_request(obj):
+        yield out
+
+
 @app.post("/generate")
 async def generate_request(obj: GenerateReqInput):
     obj.post_init()
-    result_generator = tokenizer_manager.generate_request(obj)
 
     if obj.stream:
 
         async def stream_results():
-            async for out in result_generator:
-                yield (json.dumps(out) + "\0").encode("utf-8")
+            async for out in stream_generator(obj):
+                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            yield "data: [DONE]\n\n"
 
         return StreamingResponse(stream_results(), media_type="text/event-stream")
-    else:
-        ret = await result_generator.__anext__()
-        return ret
+
+    ret = await tokenizer_manager.generate_request(obj).__anext__()
+    return ret
 
 
 @app.post("/v1/completions")
-async def v1_completions(obj: CompletionRequest):
-    assert obj.n == 1
-    obj = GenerateReqInput(
-        text=obj.prompt,
+async def v1_completions(raw_request: Request):
+    request_json = await raw_request.json()
+    request = CompletionRequest(**request_json)
+
+    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
+    assert request.n == 1
+
+    adapted_request = GenerateReqInput(
+        text=request.prompt,
         sampling_params={
-            "temperature": obj.temperature,
-            "max_new_tokens": obj.max_tokens,
-            "stop": obj.stop,
+            "temperature": request.temperature,
+            "max_new_tokens": request.max_tokens,
+            "stop": request.stop,
+            "top_p": request.top_p,
+            "presence_penalty": request.presence_penalty,
+            "frequency_penalty": request.frequency_penalty,
         },
+        stream=request.stream,
     )
-    ret = await generate_request(obj)
-    return {
-        "choices": [{"text": ret["text"]}],
-    }
+    adapted_request.post_init()
+
+    if adapted_request.stream:
+
+        async def generate_stream_resp():
+            stream_buffer = ""
+            async for content in stream_generator(adapted_request):
+                text = content["text"]
+                delta = text[len(stream_buffer):]
+                stream_buffer = text
+                choice_data = CompletionResponseStreamChoice(
+                    index=0,
+                    text=delta,
+                    logprobs=None,
+                    finish_reason=None,
+                )
+                chunk = CompletionStreamResponse(
+                    id=content["meta_info"]["id"],
+                    object="text_completion",
+                    choices=[choice_data],
+                    model=request.model,
+                )
+                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+
+        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
+
+    # Non-streaming response.
+    ret = await generate_request(adapted_request)
+
+    choice_data = CompletionResponseChoice(
+        index=0,
+        text=ret["text"],
+        logprobs=None,
+        finish_reason=None,  # TODO(comaniac): Add finish reason.
+    )
+
+    prompt_tokens = ret["meta_info"]["prompt_tokens"]
+    completion_tokens = ret["meta_info"]["completion_tokens"]
+    response = CompletionResponse(
+        id=ret["meta_info"]["id"],
+        model=request.model,
+        choices=[choice_data],
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        ),
+    )
+    return response
 
 
 def launch_server(server_args, pipe_finish_writer):
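With these changes, /v1/completions should be consumable with the stock OpenAI client that the openai extra in pyproject.toml already targets (>=1.0). An end-to-end usage sketch; the base URL, port, API key, and model name are deployment-specific assumptions, not values from this commit:

import openai

client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

# stream=True exercises the new CompletionStreamResponse path; each SSE
# chunk carries only the newly generated delta in choices[0].text.
stream = client.completions.create(
    model="default",
    prompt="Once upon a time,",
    max_tokens=32,
    temperature=0.7,
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].text, end="", flush=True)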