Fix the error message and dependency of openai backend (#71)
This commit is contained in:
@@ -164,7 +164,8 @@ def image_qa(s, image_file, question):
|
||||
```
|
||||
|
||||
### Constrained Decoding
|
||||
Use `regex=` to specify a regular expression as a decoding constraint.
|
||||
Use `regex` to specify a regular expression as a decoding constraint.
|
||||
This is only supported for local models.
|
||||
|
||||
```python
|
||||
@sgl.function
|
||||
|
||||
@@ -18,10 +18,11 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
|
||||
"interegular", "lark", "numba", "pydantic", "diskcache", "cloudpickle"]
|
||||
openai = ["openai>=1.0"]
|
||||
anthropic = ["anthropic"]
|
||||
srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
|
||||
"zmq", "vllm>=0.2.5", "interegular", "lark", "numba",
|
||||
"pydantic", "diskcache", "cloudpickle"]
|
||||
openai = ["openai>=1.0", "numpy"]
|
||||
anthropic = ["anthropic", "numpy"]
|
||||
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
|
||||
|
||||
[project.urls]
|
||||
|
||||
@@ -77,7 +77,9 @@ class OpenAI(BaseBackend):
|
||||
):
|
||||
if sampling_params.dtype is None:
|
||||
if self.is_chat_model:
|
||||
assert s.text_.endswith("ASSISTANT:")
|
||||
if not s.text_.endswith("ASSISTANT:"):
|
||||
raise RuntimeError("This use case is not supported. "
|
||||
"For OpenAI chat models, sgl.gen must be right after sgl.assistant")
|
||||
prompt = s.messages_
|
||||
else:
|
||||
prompt = s.text_
|
||||
@@ -149,6 +151,12 @@ class OpenAI(BaseBackend):
|
||||
choices: List[str],
|
||||
temperature: float,
|
||||
):
|
||||
if self.is_chat_model:
|
||||
raise NotImplementedError(
|
||||
"select/choices is not supported for chat models. "
|
||||
"Please try to use a non-chat model such as gpt-3.5-turbo-instruct"
|
||||
)
|
||||
|
||||
n_choices = len(choices)
|
||||
token_ids = [self.tokenizer.encode(x) for x in choices]
|
||||
scores = [0] * n_choices
|
||||
|
||||
@@ -197,16 +197,7 @@ class StreamExecutor:
|
||||
self.stream_var_event = None
|
||||
|
||||
def submit(self, expr: SglExpr):
|
||||
if isinstance(expr, (SglGen, SglSelect, SglVarScopeBegin)):
|
||||
self.variable_event[expr.name] = threading.Event()
|
||||
if self.stream:
|
||||
self.stream_var_event[expr.name] = threading.Event()
|
||||
elif isinstance(expr, SglExprList):
|
||||
for e in expr.expr_list:
|
||||
if isinstance(e, (SglGen, SglSelect, SglVarScopeBegin)):
|
||||
self.variable_event[e.name] = threading.Event()
|
||||
if self.stream:
|
||||
self.stream_var_event[e.name] = threading.Event()
|
||||
self._init_var_event(expr)
|
||||
|
||||
if self.use_thread:
|
||||
self.queue.put(expr)
|
||||
@@ -467,6 +458,15 @@ class StreamExecutor:
|
||||
src_rids = [state.stream_executor.sid for state in expr.states]
|
||||
self.backend.concatenate_and_append(src_rids, self.sid)
|
||||
|
||||
def _init_var_event(self, expr):
    """Create the synchronization Event(s) for every variable-producing
    expression reachable from *expr*.

    A SglGen / SglSelect / SglVarScopeBegin node gets a ``threading.Event``
    registered under its name in ``self.variable_event`` (and additionally in
    ``self.stream_var_event`` when streaming is enabled); a SglExprList is
    walked recursively so nested expressions are covered too.
    """
    if isinstance(expr, SglExprList):
        # Compound expression: recurse into each child.
        for child in expr.expr_list:
            self._init_var_event(child)
    elif isinstance(expr, (SglGen, SglSelect, SglVarScopeBegin)):
        # Leaf that binds a named variable: allocate its event(s).
        self.variable_event[expr.name] = threading.Event()
        if self.stream:
            self.stream_var_event[expr.name] = threading.Event()
|
||||
|
||||
def _resolve_sampling_params(self, sampling_params):
|
||||
clone = None
|
||||
for item in [
|
||||
|
||||
Reference in New Issue
Block a user