Add Together and AzureOpenAI examples (#184)

This commit is contained in:
Lianmin Zheng
2024-02-12 01:06:38 -08:00
committed by GitHub
parent 931213245c
commit bb824da41a
8 changed files with 263 additions and 16 deletions

View File

@@ -4,7 +4,7 @@ from typing import Callable, List, Optional, Union
import numpy as np
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.chat_template import get_chat_template_by_model_path, ChatTemplate
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
@@ -41,23 +41,39 @@ INSTRUCT_MODEL_NAMES = [
class OpenAI(BaseBackend):
def __init__(self, model_name, *args, **kwargs):
def __init__(self, model_name: str,
is_chat_model: Optional[bool] = None,
chat_template: Optional[ChatTemplate] = None,
is_azure: bool = False,
*args, **kwargs):
super().__init__()
if isinstance(openai, Exception):
raise openai
self.client = openai.OpenAI(*args, **kwargs)
if is_azure:
self.client = openai.AzureOpenAI(*args, **kwargs)
else:
self.client = openai.OpenAI(*args, **kwargs)
self.model_name = model_name
self.tokenizer = tiktoken.encoding_for_model(model_name)
try:
self.tokenizer = tiktoken.encoding_for_model(model_name)
except KeyError:
self.tokenizer = tiktoken.get_encoding("cl100k_base")
self.logit_bias_int = create_logit_bias_int(self.tokenizer)
if model_name in INSTRUCT_MODEL_NAMES:
self.is_chat_model = False
else:
self.is_chat_model = True
self.chat_template = chat_template or get_chat_template_by_model_path(model_name)
self.chat_template = get_chat_template("default")
if is_chat_model is not None:
self.is_chat_model = is_chat_model
else:
if model_name in INSTRUCT_MODEL_NAMES:
self.is_chat_model = False
else:
self.is_chat_model = True
self.chat_begin_str = self.chat_template.role_prefix_and_suffix["assistant"][0]
def get_chat_template(self):
return self.chat_template
@@ -69,7 +85,7 @@ class OpenAI(BaseBackend):
):
if sampling_params.dtype is None:
if self.is_chat_model:
if not s.text_.endswith("ASSISTANT:"):
if not s.text_.endswith(self.chat_begin_str):
raise RuntimeError(
"This use case is not supported. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant"
@@ -122,7 +138,11 @@ class OpenAI(BaseBackend):
):
if sampling_params.dtype is None:
if self.is_chat_model:
assert s.text_.endswith("ASSISTANT:")
if not s.text_.endswith(self.chat_begin_str):
raise RuntimeError(
"This use case is not supported. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant"
)
prompt = s.messages_
else:
prompt = s.text_
@@ -241,7 +261,10 @@ def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwa
messages=prompt, stream=True, **kwargs
)
for ret in generator:
content = ret.choices[0].delta.content
try:
content = ret.choices[0].delta.content
except IndexError:
content = None
yield content or "", {}
else:
generator = client.completions.create(