Improve docs & Add JSON decode example (#121)
This commit is contained in:
@@ -20,7 +20,7 @@ dependencies = [
|
||||
[project.optional-dependencies]
|
||||
srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
|
||||
"zmq", "vllm>=0.2.5", "interegular", "lark", "numba",
|
||||
"pydantic", "diskcache", "cloudpickle", "pillow"]
|
||||
"pydantic", "referencing", "diskcache", "cloudpickle", "pillow"]
|
||||
openai = ["openai>=1.0", "numpy"]
|
||||
anthropic = ["anthropic", "numpy"]
|
||||
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
|
||||
|
||||
@@ -30,21 +30,8 @@ def create_logit_bias_int(tokenizer):
|
||||
return mask
|
||||
|
||||
|
||||
CHAT_MODEL_NAMES = [
|
||||
# GPT-4
|
||||
"gpt-4",
|
||||
"gpt-4-32k",
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-vision-preview",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-0314",
|
||||
# GPT-3.5
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo-1106",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-0301",
|
||||
INSTRUCT_MODEL_NAMES = [
|
||||
"gpt-3.5-turbo-instruct",
|
||||
]
|
||||
|
||||
|
||||
@@ -60,10 +47,10 @@ class OpenAI(BaseBackend):
|
||||
self.tokenizer = tiktoken.encoding_for_model(model_name)
|
||||
self.logit_bias_int = create_logit_bias_int(self.tokenizer)
|
||||
|
||||
if model_name in CHAT_MODEL_NAMES:
|
||||
self.is_chat_model = True
|
||||
else:
|
||||
if model_name in INSTRUCT_MODEL_NAMES:
|
||||
self.is_chat_model = False
|
||||
else:
|
||||
self.is_chat_model = True
|
||||
|
||||
self.chat_template = get_chat_template("default")
|
||||
|
||||
@@ -235,6 +222,8 @@ def openai_completion(client, is_chat=None, prompt=None, **kwargs):
|
||||
def openai_completion_stream(client, is_chat=None, prompt=None, **kwargs):
|
||||
try:
|
||||
if is_chat:
|
||||
if kwargs["stop"] is None:
|
||||
kwargs.pop("stop")
|
||||
generator = client.chat.completions.create(
|
||||
messages=prompt, stream=True, **kwargs
|
||||
)
|
||||
|
||||
@@ -7,12 +7,19 @@ from sglang.srt.sampling_params import SamplingParams
|
||||
|
||||
@dataclass
|
||||
class GenerateReqInput:
|
||||
# The input prompt
|
||||
text: Union[List[str], str]
|
||||
# The image input
|
||||
image_data: Optional[Union[List[str], str]] = None
|
||||
# The sampling_params
|
||||
sampling_params: Union[List[Dict], Dict] = None
|
||||
# The request id
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
# Whether return logprobs of the prompts
|
||||
return_logprob: Optional[Union[List[bool], bool]] = None
|
||||
# The start location of the prompt for return_logprob
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None
|
||||
# Whether to stream output
|
||||
stream: bool = False
|
||||
|
||||
def post_init(self):
|
||||
|
||||
@@ -296,7 +296,7 @@ def test_parallel_encoding(check_answer=True):
|
||||
def test_image_qa():
|
||||
@sgl.function
|
||||
def image_qa(s, question):
|
||||
s += sgl.user(sgl.image("image.png") + question)
|
||||
s += sgl.user(sgl.image("test_image.png") + question)
|
||||
s += sgl.assistant(sgl.gen("answer"))
|
||||
|
||||
state = image_qa.run(
|
||||
|
||||
Reference in New Issue
Block a user