Rename api_num_spec_tokens -> num_api_spec_tokens (#458)
This commit is contained in:
@@ -19,7 +19,7 @@ import sglang as sgl
|
||||
from sglang import function, set_default_backend, OpenAI
|
||||
|
||||
|
||||
@function(api_num_spec_tokens=256)
|
||||
@function(num_api_spec_tokens=256)
|
||||
def gen_character_spec(s):
|
||||
s += sgl.system("You are a helpful assistant.")
|
||||
s += sgl.user("Construct a character within the following format:")
|
||||
@@ -28,7 +28,7 @@ def gen_character_spec(s):
|
||||
s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
|
||||
|
||||
|
||||
@function(api_num_spec_tokens=256)
|
||||
@function(num_api_spec_tokens=256)
|
||||
def gen_character_spec_no_few_shot(s):
|
||||
s += sgl.user("Construct a character. For each field stop with a newline\n")
|
||||
s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nAge:" + sgl.gen("age", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
|
||||
@@ -41,7 +41,7 @@ def gen_character_normal(s):
|
||||
s += sgl.assistant(sgl.gen("answer", max_tokens=64))
|
||||
|
||||
|
||||
@function(api_num_spec_tokens=1024)
|
||||
@function(num_api_spec_tokens=1024)
|
||||
def multi_turn_question(s, question_1, question_2):
|
||||
s += sgl.system("You are a helpful assistant.")
|
||||
s += sgl.user("Answer questions in the following format:")
|
||||
|
||||
@@ -5,7 +5,7 @@ python3 openai_speculative.py
|
||||
from sglang import function, gen, set_default_backend, OpenAI
|
||||
|
||||
|
||||
@function(api_num_spec_tokens=64)
|
||||
@function(num_api_spec_tokens=64)
|
||||
def gen_character_spec(s):
|
||||
s += "Construct a character within the following format:\n"
|
||||
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
|
||||
@@ -23,7 +23,7 @@ def gen_character_no_spec(s):
|
||||
s += "\nJob:" + gen("job", stop="\n") + "\n"
|
||||
|
||||
|
||||
@function(api_num_spec_tokens=64)
|
||||
@function(num_api_spec_tokens=64)
|
||||
def gen_character_spec_no_few_shot(s):
|
||||
# s += "Construct a character with name, birthday, and job:\n"
|
||||
s += "Construct a character:\n"
|
||||
|
||||
@@ -20,13 +20,13 @@ from sglang.lang.ir import (
|
||||
|
||||
|
||||
def function(
|
||||
func: Optional[Callable] = None, api_num_spec_tokens: Optional[int] = None
|
||||
func: Optional[Callable] = None, num_api_spec_tokens: Optional[int] = None
|
||||
):
|
||||
if func:
|
||||
return SglFunction(func, api_num_spec_tokens=api_num_spec_tokens)
|
||||
return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens)
|
||||
|
||||
def decorator(func):
|
||||
return SglFunction(func, api_num_spec_tokens=api_num_spec_tokens)
|
||||
return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@@ -106,12 +106,12 @@ class OpenAI(BaseBackend):
|
||||
return self.chat_template
|
||||
|
||||
def _prepare_spec_execution(self, sampling_params: SglSamplingParams,
|
||||
api_num_spec_tokens: int, spec_var_name: str):
|
||||
num_api_spec_tokens: int, spec_var_name: str):
|
||||
if "max_tokens" not in self.spec_kwargs:
|
||||
self.spec_kwargs["max_tokens"] = api_num_spec_tokens
|
||||
self.spec_kwargs["max_tokens"] = num_api_spec_tokens
|
||||
else:
|
||||
assert (
|
||||
self.spec_kwargs["max_tokens"] == api_num_spec_tokens
|
||||
self.spec_kwargs["max_tokens"] == num_api_spec_tokens
|
||||
)
|
||||
|
||||
params = sampling_params.to_openai_kwargs()
|
||||
@@ -142,17 +142,17 @@ class OpenAI(BaseBackend):
|
||||
):
|
||||
if sampling_params.dtype is None:
|
||||
if self.is_chat_model:
|
||||
if s.api_num_spec_tokens is None:
|
||||
if s.num_api_spec_tokens is None:
|
||||
if not s.text_.endswith(self.chat_prefix):
|
||||
raise RuntimeError(
|
||||
"This use case is not supported if api speculative execution is off. "
|
||||
"For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
|
||||
"Example of adding api speculative execution: @function(api_num_spec_tokens=128)."
|
||||
"Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
|
||||
)
|
||||
prompt = s.messages_
|
||||
else:
|
||||
return self._prepare_spec_execution(sampling_params,
|
||||
s.api_num_spec_tokens, spec_var_name)
|
||||
s.num_api_spec_tokens, spec_var_name)
|
||||
else:
|
||||
prompt = s.text_
|
||||
|
||||
@@ -230,7 +230,7 @@ class OpenAI(BaseBackend):
|
||||
self,
|
||||
s: StreamExecutor,
|
||||
):
|
||||
if s.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
|
||||
if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
|
||||
return
|
||||
|
||||
comp = ""
|
||||
|
||||
@@ -66,7 +66,7 @@ def run_program(
|
||||
default_sampling_para,
|
||||
chat_template=None,
|
||||
stream=stream,
|
||||
api_num_spec_tokens=program.api_num_spec_tokens,
|
||||
num_api_spec_tokens=program.num_api_spec_tokens,
|
||||
)
|
||||
state = ProgramState(stream_executor)
|
||||
|
||||
@@ -178,7 +178,7 @@ class StreamExecutor:
|
||||
default_sampling_para,
|
||||
chat_template,
|
||||
stream,
|
||||
api_num_spec_tokens=None,
|
||||
num_api_spec_tokens=None,
|
||||
use_thread=True,
|
||||
):
|
||||
self.sid = uuid.uuid4().hex
|
||||
@@ -210,7 +210,7 @@ class StreamExecutor:
|
||||
self.fork_start_text_pos = None
|
||||
|
||||
# For speculative execution
|
||||
self.api_num_spec_tokens = api_num_spec_tokens
|
||||
self.num_api_spec_tokens = num_api_spec_tokens
|
||||
self.speculated_text = ""
|
||||
|
||||
# Worker thread
|
||||
@@ -399,7 +399,7 @@ class StreamExecutor:
|
||||
|
||||
if (
|
||||
self.cur_role == "assistant"
|
||||
and self.api_num_spec_tokens is not None
|
||||
and self.num_api_spec_tokens is not None
|
||||
and self.backend.is_chat_model
|
||||
and not prefix
|
||||
):
|
||||
@@ -444,7 +444,7 @@ class StreamExecutor:
|
||||
nonlocal meta_info
|
||||
|
||||
sampling_params.max_new_tokens = max(
|
||||
sampling_params.max_new_tokens, self.api_num_spec_tokens
|
||||
sampling_params.max_new_tokens, self.num_api_spec_tokens
|
||||
)
|
||||
sampling_params.stop = None
|
||||
self.speculated_text, meta_info = self.backend.generate(
|
||||
@@ -490,7 +490,7 @@ class StreamExecutor:
|
||||
name = expr.name
|
||||
|
||||
if not self.stream:
|
||||
if self.api_num_spec_tokens is None:
|
||||
if self.num_api_spec_tokens is None:
|
||||
comp, meta_info = self.backend.generate(
|
||||
self,
|
||||
sampling_params=sampling_params,
|
||||
@@ -517,7 +517,7 @@ class StreamExecutor:
|
||||
self.variable_event[name].set()
|
||||
else:
|
||||
assert (
|
||||
self.api_num_spec_tokens is None
|
||||
self.num_api_spec_tokens is None
|
||||
), "stream is not supported with api speculative execution"
|
||||
generator = self.backend.generate_stream(
|
||||
self, sampling_params=sampling_params
|
||||
@@ -580,7 +580,7 @@ class StreamExecutor:
|
||||
def _execute_role_end(self, expr: SglRoleEnd):
|
||||
if (
|
||||
self.cur_role == "assistant"
|
||||
and self.api_num_spec_tokens is not None
|
||||
and self.num_api_spec_tokens is not None
|
||||
and self.backend.is_chat_model
|
||||
):
|
||||
# Execute the stored lazy generation calls
|
||||
|
||||
@@ -97,9 +97,9 @@ class SglSamplingParams:
|
||||
|
||||
|
||||
class SglFunction:
|
||||
def __init__(self, func, api_num_spec_tokens=None, bind_arguments=None):
|
||||
def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
|
||||
self.func = func
|
||||
self.api_num_spec_tokens = api_num_spec_tokens
|
||||
self.num_api_spec_tokens = num_api_spec_tokens
|
||||
self.bind_arguments = bind_arguments or {}
|
||||
self.pin_prefix_rid = None
|
||||
|
||||
|
||||
@@ -353,7 +353,7 @@ def test_regex():
|
||||
|
||||
|
||||
def test_completion_speculative():
|
||||
@sgl.function(api_num_spec_tokens=64)
|
||||
@sgl.function(num_api_spec_tokens=64)
|
||||
def gen_character_spec(s):
|
||||
s += "Construct a character within the following format:\n"
|
||||
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
|
||||
@@ -384,7 +384,7 @@ def test_completion_speculative():
|
||||
|
||||
|
||||
def test_chat_completion_speculative():
|
||||
@sgl.function(api_num_spec_tokens=256)
|
||||
@sgl.function(num_api_spec_tokens=256)
|
||||
def gen_character_spec(s):
|
||||
s += sgl.system("You are a helpful assistant.")
|
||||
s += sgl.user("Construct a character within the following format:")
|
||||
|
||||
Reference in New Issue
Block a user