openai chat speculative execution (#250)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
@@ -14,10 +14,25 @@ def gen_character_spec(s):
     s += "\nJob:" + gen("job", stop="\n") + "\n"
 
 
+@function(api_num_spec_tokens=64)
+def gen_character_spec_no_few_shot(s):
+    # s += "Construct a character with name, birthday, and job:\n"
+    s += "Construct a character:\n"
+    s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
+    s += "\nJob:" + gen("job", stop="\n") + "\n"
+
+
 set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))
 
 state = gen_character_spec.run()
 
-print("name:", state["name"])
-print("birthday:", state["birthday"])
-print("job:", state["job"])
+print("...name:", state["name"])
+print("...birthday:", state["birthday"])
+print("...job:", state["job"])
+
+state = gen_character_spec_no_few_shot.run()
+
+print("\n...name:", state["name"])
+print("...birthday:", state["birthday"])
+print("...job:", state["job"])
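Note on the decorator above: with api_num_spec_tokens=64, the completion backend over-generates once, and later gen() calls are served by slicing the speculated text at their stop strings instead of issuing one API call per field. A minimal sketch of that slicing, using a hypothetical speculated completion (not the library's actual code):

    # hypothetical single over-generated completion
    speculated = " Alan Turing\nBirthday: June 23, 1912\nJob: Mathematician\n"

    def take_until(buf, stop):
        # consume buf up to the first stop string; the remainder stays speculated
        pos = buf.find(stop)
        if pos == -1:
            return buf, ""
        return buf[:pos], buf[pos:]

    name, speculated = take_until(speculated, "\n")
    print(name)  # " Alan Turing"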
examples/usage/openaichat_speculative.py (new file, 123 lines)
@@ -0,0 +1,123 @@
+"""
+Usage:
+Note: for speculative execution to work, the user must put all "gen" calls inside
+a single "assistant" message and show the desired answer format there. Each "gen"
+term should have a stop token. Stream mode is not supported with speculative
+execution.
+
+E.g.
+correct:
+sgl.assistant("\nName:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
+
+incorrect:
+s += sgl.assistant("\nName:" + sgl.gen("name", stop="\n"))
+s += sgl.assistant("\nBirthday:" + sgl.gen("birthday", stop="\n"))
+s += sgl.assistant("\nJob:" + sgl.gen("job", stop="\n"))
+
+export OPENAI_API_KEY=sk-******
+python3 openaichat_speculative.py
+"""
+
+import sglang as sgl
+from sglang import function, gen, set_default_backend, OpenAI
+
+
+@function(api_num_spec_tokens=512)
+def gen_character_spec(s):
+    s += sgl.system("You are a helpful assistant.")
+    s += sgl.user("Construct a character within the following format:")
+    s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n")
+    s += sgl.user("Please generate new Name, Birthday and Job.\n")
+    s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
+
+
+@function(api_num_spec_tokens=512)
+def gen_character_spec_no_few_shot(s):
+    s += sgl.user("Construct a character. For each field stop with a newline.\n")
+    s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nAge:" + sgl.gen("age", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
+
+
+@function
+def gen_character_normal(s):
+    s += sgl.system("You are a helpful assistant.")
+    s += sgl.user("What's the answer of 23 + 8?")
+    s += sgl.assistant(sgl.gen("answer", max_tokens=64))
+
+
+@function(api_num_spec_tokens=1024)
+def multi_turn_question(s, question_1, question_2):
+    s += sgl.system("You are a helpful assistant.")
+    s += sgl.user("Answer questions in the following format:")
+    s += sgl.user("Question 1: What is the capital of France?\nQuestion 2: What is the population of this city?\n")
+    s += sgl.assistant("Answer 1: The capital of France is Paris.\nAnswer 2: The population of Paris in 2024 is estimated to be around 2.1 million for the city proper.\n")
+    s += sgl.user("Question 1: " + question_1 + "\nQuestion 2: " + question_2)
+    s += sgl.assistant("Answer 1: " + sgl.gen("answer_1", stop="\n") + "\nAnswer 2: " + sgl.gen("answer_2", stop="\n"))
+
+
+def test_spec_single_turn():
+    state = gen_character_spec.run()
+    for m in state.messages():
+        print(m["role"], ":", m["content"])
+
+    print("\n-- name:", state["name"])
+    print("\n-- birthday:", state["birthday"])
+    print("\n-- job:", state["job"])
+
+
+def test_inaccurate_spec_single_turn():
+    state = gen_character_spec_no_few_shot.run()
+    for m in state.messages():
+        print(m["role"], ":", m["content"])
+
+    print("\n-- name:", state["name"])
+    print("\n-- age:", state["age"])
+    print("\n-- job:", state["job"])
+
+
+def test_normal_single_turn():
+    state = gen_character_normal.run()
+    for m in state.messages():
+        print(m["role"], ":", m["content"])
+
+
+def test_spec_multi_turn():
+    state = multi_turn_question.run(
+        question_1="What is the capital of the United States?",
+        question_2="List two local attractions in the capital of the United States.",
+    )
+
+    for m in state.messages():
+        print(m["role"], ":", m["content"])
+
+    print("\n-- answer_1 --\n", state["answer_1"])
+    print("\n-- answer_2 --\n", state["answer_2"])
+
+
+def test_spec_multi_turn_stream():
+    state = multi_turn_question.run(
+        question_1="What is the capital of the United States?",
+        question_2="List two local attractions.",
+        stream=True,
+    )
+
+    for out in state.text_iter():
+        print(out, end="", flush=True)
+
+
+if __name__ == "__main__":
+    set_default_backend(OpenAI("gpt-4-turbo"))
+
+    print("\n========== test spec single turn ==========\n")
+    # expect a reasonable answer for each field
+    test_spec_single_turn()
+
+    print("\n========== test inaccurate spec single turn ==========\n")
+    # expect incomplete or unreasonable answers
+    test_inaccurate_spec_single_turn()
+
+    print("\n========== test normal single turn ==========\n")
+    # expect a reasonable answer
+    test_normal_single_turn()
+
+    print("\n========== test spec multi turn ==========\n")
+    # expect answers in the same format as the few-shot example
+    test_spec_multi_turn()
+
+    print("\n========== test spec multi turn stream ==========\n")
+    # expect error in stream_executor: stream is not supported...
+    test_spec_multi_turn_stream()
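The correct/incorrect rule in the docstring exists because, with api_num_spec_tokens set on a chat model, the whole assistant turn is produced by one chat-completion call; each literal and each gen() in that turn is only recorded as a slot to be filled from the single speculated answer. A sketch of what would get recorded for gen_character_spec (the concrete values are assumptions based on the spec_format structure in the backend diff below):

    spec_format = [
        {"text": "Name:", "stop": None, "name": None},         # literal from sgl.assistant(...)
        {"text": "", "stop": "\n", "name": "name"},            # sgl.gen("name", stop="\n")
        {"text": "\nBirthday:", "stop": None, "name": None},
        {"text": "", "stop": "\n", "name": "birthday"},
        {"text": "\nJob:", "stop": None, "name": None},
        {"text": "", "stop": "\n", "name": "job"},
    ]

Splitting the gens across several sgl.assistant(...) calls would force one API call per turn, which is exactly what speculative execution avoids.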
@@ -9,6 +9,7 @@ class BaseBackend:
     def __init__(self) -> None:
         self.support_concate_and_append = False
         self.chat_template = get_chat_template("default")
+        self.api_num_spec_tokens = None
 
     def get_model_name(self):
         raise NotImplementedError()
@@ -1,5 +1,6 @@
 import logging
 import time
+import warnings
 from typing import Callable, List, Optional, Union
 
 import numpy as np
@@ -80,7 +81,15 @@ class OpenAI(BaseBackend):
         else:
             self.is_chat_model = True
 
-        self.chat_begin_str = self.chat_template.role_prefix_and_suffix["assistant"][0]
+        self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]
+
+        self.spec_kwargs = {}
+        self.spec_format = []
+        self.spec_max_num_tries = 3
+        self.api_num_spec_tokens = None
+
+    def set_api_num_spec_tokens(self, num):
+        self.api_num_spec_tokens = num
 
     def get_chat_template(self):
         return self.chat_template
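chat_prefix is the assistant role prefix taken from the chat template, so checking s.text_.endswith(self.chat_prefix) tells the backend whether gen() sits at the very start of an assistant turn. A small illustration with a made-up template (real prefixes depend on the template in use):

    # hypothetical role prefixes; real ones come from the chat template
    role_prefix_and_suffix = {"assistant": ("ASSISTANT:", "\n")}
    chat_prefix = role_prefix_and_suffix["assistant"][0]

    text_ = "SYSTEM: You are helpful.\nUSER: Hi.\nASSISTANT:"
    print(text_.endswith(chat_prefix))  # True -> gen() directly follows sgl.assistant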
@@ -89,15 +98,45 @@ class OpenAI(BaseBackend):
         self,
         s: StreamExecutor,
         sampling_params: SglSamplingParams,
+        name=None,
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-                if not s.text_.endswith(self.chat_begin_str):
-                    raise RuntimeError(
-                        "This use case is not supported. "
-                        "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
-                    )
-                prompt = s.messages_
+                if self.api_num_spec_tokens is None:
+                    if not s.text_.endswith(self.chat_prefix):
+                        raise RuntimeError(
+                            "This use case is not supported if api speculative execution is off. "
+                            "For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
+                            "Example of adding api speculative execution: @function(api_num_spec_tokens=128)."
+                        )
+                    prompt = s.messages_
+                else:
+                    # collect the assistant answer format
+                    if "max_tokens" not in self.spec_kwargs:
+                        self.spec_kwargs["max_tokens"] = self.api_num_spec_tokens
+                    else:
+                        assert (
+                            self.spec_kwargs["max_tokens"] == self.api_num_spec_tokens
+                        )
+                    params = sampling_params.to_openai_kwargs()
+                    for key, value in params.items():
+                        if key in ["stop"]:
+                            continue
+                        if key in ["max_tokens"]:
+                            warnings.warn(
+                                "The parameter max_tokens will be overwritten by the speculated number of tokens."
+                            )
+                            continue
+                        if key not in self.spec_kwargs:
+                            self.spec_kwargs[key] = value
+                        else:
+                            assert (
+                                value == self.spec_kwargs[key]
+                            ), "Sampling parameters should be consistent if api speculative execution is on."
+                    self.spec_format.append(
+                        {"text": "", "stop": params["stop"], "name": name}
+                    )
+                    return "", {}
             else:
                 prompt = s.text_
@@ -110,6 +149,9 @@ class OpenAI(BaseBackend):
                 **kwargs,
             )
         elif sampling_params.dtype in [str, "str", "string"]:
+            assert (
+                not self.is_chat_model
+            ), "constrained type not supported on chat model"
             kwargs = sampling_params.to_openai_kwargs()
             kwargs.pop("stop")
             comp = openai_completion(
@@ -122,6 +164,9 @@ class OpenAI(BaseBackend):
             )
             comp = '"' + comp + '"'
         elif sampling_params.dtype in [int, "int"]:
+            assert (
+                not self.is_chat_model
+            ), "constrained type not supported on chat model"
             kwargs = sampling_params.to_openai_kwargs()
             kwargs.pop("stop")
             comp = openai_completion(
@@ -138,6 +183,62 @@ class OpenAI(BaseBackend):
 
         return comp, {}
 
+    def spec_fill(self, value: str):
+        assert self.is_chat_model
+        self.spec_format.append({"text": value, "stop": None, "name": None})
+
+    def spec_pattern_match(self, comp):
+        for i, term in enumerate(self.spec_format):
+            text = term["text"]
+            if text != "":
+                if comp.startswith(text):
+                    comp = comp[len(text) :]
+                else:
+                    return False
+            else:
+                pos = comp.find(term["stop"])
+                if pos != -1:
+                    term["text"] = comp[:pos]
+                    comp = comp[pos:]
+                else:
+                    if i == len(self.spec_format) - 1:
+                        term["text"] = comp
+                    else:
+                        return False
+        return True
+
+    def role_end_generate(
+        self,
+        s: StreamExecutor,
+    ):
+        if self.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
+            return
+
+        comp = ""
+        if not all(x["name"] is None for x in self.spec_format):
+            for i in range(self.spec_max_num_tries):
+                comp = openai_completion(
+                    client=self.client,
+                    is_chat=self.is_chat_model,
+                    model=self.model_name,
+                    prompt=s.messages_,
+                    **self.spec_kwargs,
+                )
+                if self.spec_pattern_match(comp):
+                    break
+
+        for term in self.spec_format:
+            stop = term["stop"] if term["stop"] is not None else ""
+            s.text_ += term["text"]
+            name = term["name"]
+            if name is not None:
+                s.variables[name] = term["text"]
+                s.meta_info[name] = {}
+                s.variable_event[name].set()
+
+        self.spec_kwargs = {}
+        self.spec_format = []
+
     def generate_stream(
         self,
         s: StreamExecutor,
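spec_pattern_match walks the recorded format over the speculated completion: literals must match verbatim, and each gen slot is cut at its stop string (the last slot may absorb the remainder). A standalone sketch of the same logic, with made-up strings:

    def pattern_match(comp, spec_format):
        for i, term in enumerate(spec_format):
            if term["text"] != "":                     # literal: must match verbatim
                if not comp.startswith(term["text"]):
                    return False
                comp = comp[len(term["text"]):]
            else:                                      # gen slot: cut at its stop string
                pos = comp.find(term["stop"])
                if pos == -1:
                    if i != len(spec_format) - 1:
                        return False
                    term["text"] = comp                # last slot takes the remainder
                else:
                    term["text"], comp = comp[:pos], comp[pos:]
        return True

    fmt = [
        {"text": "Name:", "stop": None, "name": None},
        {"text": "", "stop": "\n", "name": "name"},
        {"text": "\nJob:", "stop": None, "name": None},
        {"text": "", "stop": "\n", "name": "job"},
    ]
    assert pattern_match("Name: Ada Lovelace\nJob: Mathematician\n", fmt)
    print({t["name"]: t["text"] for t in fmt if t["name"]})
    # {'name': ' Ada Lovelace', 'job': ' Mathematician'}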
@@ -145,7 +246,7 @@ class OpenAI(BaseBackend):
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-                if not s.text_.endswith(self.chat_begin_str):
+                if not s.text_.endswith(self.chat_prefix):
                     raise RuntimeError(
                         "This use case is not supported. "
                         "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
@@ -6,6 +6,7 @@ import multiprocessing
 import queue
 import threading
 import uuid
+import warnings
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Union
@@ -185,7 +186,6 @@ class StreamExecutor:
         self.arguments: Dict[str, Any] = arguments
         self.default_sampling_para = default_sampling_para
         self.stream = stream
-        self.api_num_spec_tokens = api_num_spec_tokens
 
         self.variables = {}  # Dict[name: str -> value: str]
         self.variable_event = {}  # Dict[name: str -> event: threading.Event]
@@ -197,6 +197,9 @@ class StreamExecutor:
         self.text_ = ""  # The full text
 
         # For speculative execution
+        from sglang.backend.openai import OpenAI
+        if isinstance(backend, OpenAI):
+            self.backend.set_api_num_spec_tokens(api_num_spec_tokens)
         self.speculated_text = ""
 
         # For chat
@@ -322,7 +325,7 @@ class StreamExecutor:
             try:
                 self._execute(expr)
             except Exception as e:
-                # print(f"Error in stream_executor: {get_exception_traceback()}")
+                warnings.warn(f"Error in stream_executor: {get_exception_traceback()}")
                 error = e
                 break
             self.queue.task_done()
@@ -391,12 +394,23 @@ class StreamExecutor:
         else:
             raise ValueError(f"Unknown type: {type(other)}")
 
-    def _execute_fill(self, value: str):
+    def _execute_fill(self, value: str, prefix=False):
         value = str(value)
+
+        if (
+            self.cur_role == "assistant"
+            and self.backend.api_num_spec_tokens is not None
+            and self.backend.is_chat_model
+            and not prefix
+        ):
+            self.backend.spec_fill(value)
+            return
+
         if self.speculated_text.startswith(value):
             self.speculated_text = self.speculated_text[len(value) :]
         else:
             self.speculated_text = ""
 
         self.text_ += value
 
     def _execute_image(self, expr: SglImage):
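The speculative branch of _execute_fill above sits in front of the consumption rule below it: deterministic text that is about to be appended either matches the head of the speculated text (and is consumed from it) or invalidates the speculation. In isolation, with made-up strings:

    speculated_text = "\nBirthday: June 23, 1912\nJob: Mathematician"
    value = "\nBirthday:"
    if speculated_text.startswith(value):
        speculated_text = speculated_text[len(value):]   # speculation still valid
    else:
        speculated_text = ""                             # mismatch: drop the speculation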
@@ -426,14 +440,29 @@ class StreamExecutor:
         name = expr.name
 
         if not self.stream:
-            if self.api_num_spec_tokens is not None:
+            if self.backend.api_num_spec_tokens is None:
+                comp, meta_info = self.backend.generate(
+                    self,
+                    sampling_params=sampling_params,
+                )
+
+            elif self.backend.is_chat_model:
+                # spec on model with only chat interface
+                comp, meta_info = self.backend.generate(
+                    self,
+                    sampling_params=sampling_params,
+                    name=name,
+                )
+                return
+
+            else:  # spec on model with completion
                 stop = sampling_params.stop
                 max_new_tokens = sampling_params.max_new_tokens
                 meta_info = {}
 
                 def regen():
                     sampling_params.max_new_tokens = max(
-                        sampling_params.max_new_tokens, self.api_num_spec_tokens
+                        sampling_params.max_new_tokens, self.backend.api_num_spec_tokens
                     )
                     sampling_params.stop = None
                     self.speculated_text, meta_info = self.backend.generate(
@@ -442,16 +471,14 @@ class StreamExecutor:
 
                 def find_stop():
                     if isinstance(stop, str):
-                        return self.speculated_text.find(stop), len(stop)
+                        return self.speculated_text.find(stop)
                     elif isinstance(stop, (tuple, list)):
                         pos = -1
-                        stop_len = 0
                         for stop_str in stop:
                             stop_pos = self.speculated_text.find(stop_str)
                             if stop_pos != -1 and (pos == -1 or stop_pos < pos):
                                 pos = stop_pos
-                                stop_len = len(stop_str)
-                        return pos, stop_len
+                        return pos
                     else:
                         raise Exception("Wrong type of stop in sampling parameters.")
 
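With a list of stop strings, find_stop now returns the earliest match across all of them (the matched stop's length is no longer tracked, since the stop text simply stays at the head of speculated_text). A standalone sketch of that search, with made-up inputs:

    def find_first_stop(text, stops):
        # earliest occurrence across all stop strings, or -1 if none matches
        pos = -1
        for stop in stops:
            p = text.find(stop)
            if p != -1 and (pos == -1 or p < pos):
                pos = p
        return pos

    print(find_first_stop("Name: Ada\nJob: Poet", ["\n", "Job"]))  # 9 (the "\n" wins)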
@@ -463,23 +490,16 @@ class StreamExecutor:
                 elif isinstance(stop, (str, list, tuple)):
                     if self.speculated_text == "":
                         regen()
-                    stop_pos, stop_len = find_stop()
+                    stop_pos = find_stop()
                     if stop_pos == -1:
-                        stop_pos, stop_len = (
-                            min(
-                                sampling_params.max_new_tokens,
-                                len(self.speculated_text),
-                            ),
-                            0,
-                        )
+                        stop_pos = min(
+                            sampling_params.max_new_tokens,
+                            len(self.speculated_text),
+                        )
                     comp = self.speculated_text[:stop_pos]
                     self.speculated_text = self.speculated_text[stop_pos:]
                 else:
                     raise ValueError("Wrong type of stop in sampling parameters.")
-        else:
-            comp, meta_info = self.backend.generate(
-                self, sampling_params=sampling_params
-            )
 
         self.text_ += comp
 
@@ -487,6 +507,9 @@ class StreamExecutor:
             self.meta_info[name] = meta_info
             self.variable_event[name].set()
         else:
+            assert (
+                self.backend.api_num_spec_tokens is None
+            ), "stream is not supported with api speculative execution"
             generator = self.backend.generate_stream(
                 self, sampling_params=sampling_params
             )
@@ -542,10 +565,18 @@ class StreamExecutor:
 
         prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
 
-        self._execute_fill(prefix)
+        self._execute_fill(prefix, prefix=True)
         self.cur_role_begin_pos = len(self.text_)
 
     def _execute_role_end(self, expr: SglRoleEnd):
+        if (
+            self.cur_role == "assistant"
+            and self.backend.api_num_spec_tokens is not None
+            and self.backend.is_chat_model
+        ):
+            self.backend.role_end_generate(self)
+        self.cur_role = None
+
         new_text = self.text_[self.cur_role_begin_pos :].lstrip()
 
         _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
@@ -572,8 +603,6 @@ class StreamExecutor:
             # OpenAI chat API format
             self.messages_.append({"role": expr.role, "content": new_text})
 
-        self.cur_role = None
-
     def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
         self.variables[expr.name] = int(len(self.text_))
@@ -31,8 +31,9 @@ class GenerateReqInput:
 
     def post_init(self):
-        if ((self.text is None and self.input_ids is None) or
-            (self.text is not None and self.input_ids is not None)):
+        if (self.text is None and self.input_ids is None) or (
+            self.text is not None and self.input_ids is not None
+        ):
             raise ValueError("Either text or input_ids should be provided.")
 
         if self.text is not None:
@@ -38,7 +38,6 @@ from sglang.srt.utils import (
 )
 from sglang.utils import get_exception_traceback
 
-
 logger = logging.getLogger("model_rpc")
 vllm_default_logger.setLevel(logging.WARN)
 logging.getLogger("vllm.utils").setLevel(logging.WARN)
@@ -341,7 +341,6 @@ class TokenizerManager:
         return top_logprobs
 
 
-
 global global_processor
 
 
@@ -9,8 +9,8 @@ import os
 import sys
 import threading
 import time
-from typing import List, Optional, Union
 from http import HTTPStatus
+from typing import List, Optional, Union
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -45,7 +45,6 @@ from sglang.srt.utils import (
 )
 from sglang.utils import get_exception_traceback
 
-
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 
@@ -84,6 +83,7 @@ async def flush_cache():
 
 async def generate_request(obj: GenerateReqInput, request: Request):
     if obj.stream:
+
         async def stream_results():
             try:
                 async for out in tokenizer_manager.generate_request(obj, request):
@@ -99,8 +99,10 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return JSONResponse({"error": {"message": str(e)}},
-                            status_code=HTTPStatus.BAD_REQUEST)
+        return JSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
 
 
 app.post("/generate")(generate_request)
 app.put("/generate")(generate_request)
@@ -19,7 +19,6 @@ from packaging import version as pkg_version
 from pydantic import BaseModel
 from starlette.middleware.base import BaseHTTPMiddleware
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -157,7 +156,9 @@ def allocate_init_ports(
         cur_port += 1
 
     if port and ret_ports[0] != port:
-        logger.warn(f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead.")
+        logger.warn(
+            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
+        )
 
     return ret_ports[0], ret_ports[1:]