Improve docs & Add JSON decode example (#121)

This commit is contained in:
Lianmin Zheng
2024-01-30 05:45:27 -08:00
committed by GitHub
parent 0617528632
commit 97aa9b3284
19 changed files with 212 additions and 61 deletions

View File

@@ -20,7 +20,7 @@ dependencies = [
[project.optional-dependencies]
srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
"zmq", "vllm>=0.2.5", "interegular", "lark", "numba",
"pydantic", "diskcache", "cloudpickle", "pillow"]
"pydantic", "referencing", "diskcache", "cloudpickle", "pillow"]
openai = ["openai>=1.0", "numpy"]
anthropic = ["anthropic", "numpy"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]

View File

@@ -30,21 +30,8 @@ def create_logit_bias_int(tokenizer):
return mask
CHAT_MODEL_NAMES = [
# GPT-4
"gpt-4",
"gpt-4-32k",
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gpt-4-0613",
"gpt-4-0314",
# GPT-3.5
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-0301",
INSTRUCT_MODEL_NAMES = [
"gpt-3.5-turbo-instruct",
]
@@ -60,10 +47,10 @@ class OpenAI(BaseBackend):
self.tokenizer = tiktoken.encoding_for_model(model_name)
self.logit_bias_int = create_logit_bias_int(self.tokenizer)
if model_name in CHAT_MODEL_NAMES:
self.is_chat_model = True
else:
if model_name in INSTRUCT_MODEL_NAMES:
self.is_chat_model = False
else:
self.is_chat_model = True
self.chat_template = get_chat_template("default")
@@ -235,6 +222,8 @@ def openai_completion(client, is_chat=None, prompt=None, **kwargs):
def openai_completion_stream(client, is_chat=None, prompt=None, **kwargs):
try:
if is_chat:
if kwargs["stop"] is None:
kwargs.pop("stop")
generator = client.chat.completions.create(
messages=prompt, stream=True, **kwargs
)

View File

@@ -7,12 +7,19 @@ from sglang.srt.sampling_params import SamplingParams
@dataclass
class GenerateReqInput:
# The input prompt
text: Union[List[str], str]
# The image input
image_data: Optional[Union[List[str], str]] = None
# The sampling_params
sampling_params: Union[List[Dict], Dict] = None
# The request id
rid: Optional[Union[List[str], str]] = None
# Whether return logprobs of the prompts
return_logprob: Optional[Union[List[bool], bool]] = None
# The start location of the prompt for return_logprob
logprob_start_len: Optional[Union[List[int], int]] = None
# Whether to stream output
stream: bool = False
def post_init(self):

View File

@@ -296,7 +296,7 @@ def test_parallel_encoding(check_answer=True):
def test_image_qa():
@sgl.function
def image_qa(s, question):
s += sgl.user(sgl.image("image.png") + question)
s += sgl.user(sgl.image("test_image.png") + question)
s += sgl.assistant(sgl.gen("answer"))
state = image_qa.run(