Add Together and AzureOpenAI examples (#184)
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from sglang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template
|
||||
from sglang.lang.chat_template import get_chat_template_by_model_path, ChatTemplate
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
@@ -41,23 +41,39 @@ INSTRUCT_MODEL_NAMES = [
|
||||
|
||||
|
||||
class OpenAI(BaseBackend):
|
||||
def __init__(self, model_name, *args, **kwargs):
|
||||
def __init__(self, model_name: str,
|
||||
is_chat_model: Optional[bool] = None,
|
||||
chat_template: Optional[ChatTemplate] = None,
|
||||
is_azure: bool = False,
|
||||
*args, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
if isinstance(openai, Exception):
|
||||
raise openai
|
||||
|
||||
self.client = openai.OpenAI(*args, **kwargs)
|
||||
if is_azure:
|
||||
self.client = openai.AzureOpenAI(*args, **kwargs)
|
||||
else:
|
||||
self.client = openai.OpenAI(*args, **kwargs)
|
||||
|
||||
self.model_name = model_name
|
||||
self.tokenizer = tiktoken.encoding_for_model(model_name)
|
||||
try:
|
||||
self.tokenizer = tiktoken.encoding_for_model(model_name)
|
||||
except KeyError:
|
||||
self.tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||
self.logit_bias_int = create_logit_bias_int(self.tokenizer)
|
||||
|
||||
if model_name in INSTRUCT_MODEL_NAMES:
|
||||
self.is_chat_model = False
|
||||
else:
|
||||
self.is_chat_model = True
|
||||
self.chat_template = chat_template or get_chat_template_by_model_path(model_name)
|
||||
|
||||
self.chat_template = get_chat_template("default")
|
||||
if is_chat_model is not None:
|
||||
self.is_chat_model = is_chat_model
|
||||
else:
|
||||
if model_name in INSTRUCT_MODEL_NAMES:
|
||||
self.is_chat_model = False
|
||||
else:
|
||||
self.is_chat_model = True
|
||||
|
||||
self.chat_begin_str = self.chat_template.role_prefix_and_suffix["assistant"][0]
|
||||
|
||||
def get_chat_template(self):
|
||||
return self.chat_template
|
||||
@@ -69,7 +85,7 @@ class OpenAI(BaseBackend):
|
||||
):
|
||||
if sampling_params.dtype is None:
|
||||
if self.is_chat_model:
|
||||
if not s.text_.endswith("ASSISTANT:"):
|
||||
if not s.text_.endswith(self.chat_begin_str):
|
||||
raise RuntimeError(
|
||||
"This use case is not supported. "
|
||||
"For OpenAI chat models, sgl.gen must be right after sgl.assistant"
|
||||
@@ -122,7 +138,11 @@ class OpenAI(BaseBackend):
|
||||
):
|
||||
if sampling_params.dtype is None:
|
||||
if self.is_chat_model:
|
||||
assert s.text_.endswith("ASSISTANT:")
|
||||
if not s.text_.endswith(self.chat_begin_str):
|
||||
raise RuntimeError(
|
||||
"This use case is not supported. "
|
||||
"For OpenAI chat models, sgl.gen must be right after sgl.assistant"
|
||||
)
|
||||
prompt = s.messages_
|
||||
else:
|
||||
prompt = s.text_
|
||||
@@ -241,7 +261,10 @@ def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwa
|
||||
messages=prompt, stream=True, **kwargs
|
||||
)
|
||||
for ret in generator:
|
||||
content = ret.choices[0].delta.content
|
||||
try:
|
||||
content = ret.choices[0].delta.content
|
||||
except IndexError:
|
||||
content = None
|
||||
yield content or "", {}
|
||||
else:
|
||||
generator = client.completions.create(
|
||||
|
||||
Reference in New Issue
Block a user