Support stream=True in v1/completions (#49)
This commit is contained in:
18
README.md
18
README.md
@@ -238,9 +238,25 @@ curl http://localhost:30000/generate \
|
|||||||
}
|
}
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
Learn more about the argument format [here](docs/sampling_params.md).
|
Learn more about the argument format [here](docs/sampling_params.md).
|
||||||
|
|
||||||
|
### OpenAI Compatible API
|
||||||
|
|
||||||
|
In addition, the server supports an experimental OpenAI-compatible API.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
client = openai.Client(
|
||||||
|
base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
|
||||||
|
response = client.completions.create(
|
||||||
|
model="default",
|
||||||
|
prompt="The capital of France is",
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=32,
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
### Additional Arguments
|
### Additional Arguments
|
||||||
- Add `--tp 2` to enable tensor parallelism.
|
- Add `--tp 2` to enable tensor parallelism.
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ dependencies = [
|
|||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
|
srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
|
||||||
"interegular", "lark", "numba"]
|
"interegular", "lark", "numba", "pydantic"]
|
||||||
openai = ["openai>=1.0"]
|
openai = ["openai>=1.0"]
|
||||||
anthropic = ["anthropic"]
|
anthropic = ["anthropic"]
|
||||||
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
|
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
|
||||||
|
|||||||
@@ -116,9 +116,12 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
pos = 0
|
pos = 0
|
||||||
|
|
||||||
incomplete_text = ""
|
incomplete_text = ""
|
||||||
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
for chunk in response.iter_lines(decode_unicode=False):
|
||||||
if chunk:
|
chunk = chunk.decode("utf-8")
|
||||||
data = json.loads(chunk.decode())
|
if chunk and chunk.startswith("data:"):
|
||||||
|
if chunk == "data: [DONE]":
|
||||||
|
break
|
||||||
|
data = json.loads(chunk[5:].strip("\n"))
|
||||||
text = find_printable_text(data["text"][pos:])
|
text = find_printable_text(data["text"][pos:])
|
||||||
meta_info = data["meta_info"]
|
meta_info = data["meta_info"]
|
||||||
pos += len(text)
|
pos += len(text)
|
||||||
|
|||||||
@@ -1,12 +1,67 @@
|
|||||||
from dataclasses import dataclass
|
import time
|
||||||
from typing import Any, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class LogProbs(BaseModel):
|
||||||
class CompletionRequest:
|
text_offset: List[int] = Field(default_factory=list)
|
||||||
prompt: Union[str, List[Any]]
|
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
||||||
model: str = "default"
|
tokens: List[str] = Field(default_factory=list)
|
||||||
temperature: Optional[float] = 0.7
|
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class UsageInfo(BaseModel):
|
||||||
|
prompt_tokens: int = 0
|
||||||
|
total_tokens: int = 0
|
||||||
|
completion_tokens: Optional[int] = 0
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionRequest(BaseModel):
|
||||||
|
model: str
|
||||||
|
prompt: Union[str, List[str]]
|
||||||
|
suffix: Optional[str] = None
|
||||||
max_tokens: Optional[int] = 16
|
max_tokens: Optional[int] = 16
|
||||||
|
temperature: Optional[float] = 0.7
|
||||||
|
top_p: Optional[float] = 1.0
|
||||||
n: Optional[int] = 1
|
n: Optional[int] = 1
|
||||||
stop: Optional[Union[str, List[str]]] = None
|
stream: Optional[bool] = False
|
||||||
|
logprobs: Optional[int] = None
|
||||||
|
echo: Optional[bool] = False
|
||||||
|
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
|
||||||
|
presence_penalty: Optional[float] = 0.0
|
||||||
|
frequency_penalty: Optional[float] = 0.0
|
||||||
|
best_of: Optional[int] = None
|
||||||
|
logit_bias: Optional[Dict[str, float]] = None
|
||||||
|
user: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionResponseChoice(BaseModel):
|
||||||
|
index: int
|
||||||
|
text: str
|
||||||
|
logprobs: Optional[LogProbs] = None
|
||||||
|
finish_reason: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
object: str = "text_completion"
|
||||||
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
model: str
|
||||||
|
choices: List[CompletionResponseChoice]
|
||||||
|
usage: UsageInfo
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionResponseStreamChoice(BaseModel):
|
||||||
|
index: int
|
||||||
|
text: str
|
||||||
|
logprobs: Optional[LogProbs] = None
|
||||||
|
finish_reason: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionStreamResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
object: str = "text_completion"
|
||||||
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
model: str
|
||||||
|
choices: List[CompletionResponseStreamChoice]
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
"""SRT: SGLang Runtime"""
|
"""SRT: SGLang Runtime"""
|
||||||
import argparse
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import dataclasses
|
|
||||||
import json
|
import json
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import sys
|
import sys
|
||||||
@@ -16,12 +14,19 @@ import psutil
|
|||||||
import requests
|
import requests
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import uvloop
|
import uvloop
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
||||||
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
||||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||||
from sglang.srt.managers.openai_protocol import CompletionRequest
|
from sglang.srt.managers.openai_protocol import (
|
||||||
|
CompletionRequest,
|
||||||
|
CompletionResponse,
|
||||||
|
CompletionResponseChoice,
|
||||||
|
CompletionResponseStreamChoice,
|
||||||
|
CompletionStreamResponse,
|
||||||
|
UsageInfo
|
||||||
|
)
|
||||||
from sglang.srt.managers.router.manager import start_router_process
|
from sglang.srt.managers.router.manager import start_router_process
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
@@ -41,39 +46,97 @@ async def get_model_info():
|
|||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
async def stream_generator(obj):
|
||||||
|
async for out in tokenizer_manager.generate_request(obj):
|
||||||
|
yield out
|
||||||
|
|
||||||
|
|
||||||
@app.post("/generate")
|
@app.post("/generate")
|
||||||
async def generate_request(obj: GenerateReqInput):
|
async def generate_request(obj: GenerateReqInput):
|
||||||
obj.post_init()
|
obj.post_init()
|
||||||
result_generator = tokenizer_manager.generate_request(obj)
|
|
||||||
|
|
||||||
if obj.stream:
|
if obj.stream:
|
||||||
|
|
||||||
async def stream_results():
|
async def stream_results():
|
||||||
async for out in result_generator:
|
async for out in stream_generator(obj):
|
||||||
yield (json.dumps(out) + "\0").encode("utf-8")
|
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(stream_results(), media_type="text/event-stream")
|
return StreamingResponse(stream_results(), media_type="text/event-stream")
|
||||||
else:
|
|
||||||
ret = await result_generator.__anext__()
|
ret = await tokenizer_manager.generate_request(obj).__anext__()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/completions")
|
@app.post("/v1/completions")
|
||||||
async def v1_completions(obj: CompletionRequest):
|
async def v1_completions(raw_request: Request):
|
||||||
assert obj.n == 1
|
request_json = await raw_request.json()
|
||||||
obj = GenerateReqInput(
|
request = CompletionRequest(**request_json)
|
||||||
text=obj.prompt,
|
|
||||||
|
# TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
|
||||||
|
assert request.n == 1
|
||||||
|
|
||||||
|
adapted_request = GenerateReqInput(
|
||||||
|
text=request.prompt,
|
||||||
sampling_params={
|
sampling_params={
|
||||||
"temperature": obj.temperature,
|
"temperature": request.temperature,
|
||||||
"max_new_tokens": obj.max_tokens,
|
"max_new_tokens": request.max_tokens,
|
||||||
"stop": obj.stop,
|
"stop": request.stop,
|
||||||
|
"top_p": request.top_p,
|
||||||
|
"presence_penalty": request.presence_penalty,
|
||||||
|
"frequency_penalty": request.frequency_penalty,
|
||||||
},
|
},
|
||||||
|
stream=request.stream,
|
||||||
)
|
)
|
||||||
ret = await generate_request(obj)
|
adapted_request.post_init()
|
||||||
return {
|
|
||||||
"choices": [{"text": ret["text"]}],
|
if adapted_request.stream:
|
||||||
}
|
async def gnerate_stream_resp():
|
||||||
|
stream_buffer = ""
|
||||||
|
async for content in stream_generator(adapted_request):
|
||||||
|
text = content["text"]
|
||||||
|
delta = text[len(stream_buffer):]
|
||||||
|
stream_buffer = text
|
||||||
|
choice_data = CompletionResponseStreamChoice(
|
||||||
|
index=0,
|
||||||
|
text=delta,
|
||||||
|
logprobs=None,
|
||||||
|
finish_reason=None,
|
||||||
|
)
|
||||||
|
chunk = CompletionStreamResponse(
|
||||||
|
id=content["meta_info"]["id"],
|
||||||
|
object="text_completion",
|
||||||
|
choices=[choice_data],
|
||||||
|
model=request.model,
|
||||||
|
)
|
||||||
|
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
return StreamingResponse(gnerate_stream_resp(), media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
|
# Non-streaming response.
|
||||||
|
ret = await generate_request(adapted_request)
|
||||||
|
|
||||||
|
choice_data = CompletionResponseChoice(
|
||||||
|
index=0,
|
||||||
|
text=ret["text"],
|
||||||
|
logprobs=None,
|
||||||
|
finish_reason=None, # TODO(comaniac): Add finish reason.
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_tokens = ret["meta_info"]["prompt_tokens"]
|
||||||
|
completion_tokens = ret["meta_info"]["completion_tokens"]
|
||||||
|
response = CompletionResponse(
|
||||||
|
id=ret["meta_info"]["id"],
|
||||||
|
model=request.model,
|
||||||
|
choices=[choice_data],
|
||||||
|
usage=UsageInfo(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=prompt_tokens + completion_tokens,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
def launch_server(server_args, pipe_finish_writer):
|
def launch_server(server_args, pipe_finish_writer):
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ if __name__ == "__main__":
|
|||||||
"text": "The capital of France is",
|
"text": "The capital of France is",
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
"max_new_tokens": 1024,
|
"max_new_tokens": 512,
|
||||||
},
|
},
|
||||||
"stream": True,
|
"stream": True,
|
||||||
},
|
},
|
||||||
@@ -33,9 +33,12 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
prev = 0
|
prev = 0
|
||||||
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
for chunk in response.iter_lines(decode_unicode=False):
|
||||||
if chunk:
|
chunk = chunk.decode("utf-8")
|
||||||
data = json.loads(chunk.decode())
|
if chunk and chunk.startswith("data:"):
|
||||||
|
if chunk == "data: [DONE]":
|
||||||
|
break
|
||||||
|
data = json.loads(chunk[5:].strip("\n"))
|
||||||
output = data["text"].strip()
|
output = data["text"].strip()
|
||||||
print(output[prev:], end="", flush=True)
|
print(output[prev:], end="", flush=True)
|
||||||
prev = len(output)
|
prev = len(output)
|
||||||
|
|||||||
54
test/srt/test_openai_server.py
Normal file
54
test/srt/test_openai_server.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""
|
||||||
|
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||||
|
|
||||||
|
Output:
|
||||||
|
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import openai
|
||||||
|
|
||||||
|
|
||||||
|
def test_completion(args):
|
||||||
|
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
|
||||||
|
response = client.completions.create(
|
||||||
|
model="default",
|
||||||
|
prompt="The capital of France is",
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=32,
|
||||||
|
)
|
||||||
|
print(response.choices[0].text)
|
||||||
|
assert response.id
|
||||||
|
assert response.created
|
||||||
|
assert response.usage.prompt_tokens > 0
|
||||||
|
assert response.usage.completion_tokens > 0
|
||||||
|
assert response.usage.total_tokens > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_completion_stream(args):
|
||||||
|
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
|
||||||
|
response = client.completions.create(
|
||||||
|
model="default",
|
||||||
|
prompt="The capital of France is",
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=32,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
for r in response:
|
||||||
|
print(r.choices[0].text, end="", flush=True)
|
||||||
|
assert r.id
|
||||||
|
assert r.created
|
||||||
|
assert r.usage.prompt_tokens > 0
|
||||||
|
assert r.usage.completion_tokens > 0
|
||||||
|
assert r.usage.total_tokens > 0
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
test_completion(args)
|
||||||
|
test_completion_stream(args)
|
||||||
Reference in New Issue
Block a user