diff --git a/examples/usage/openai_chat_speculative.py b/examples/usage/openai_chat_speculative.py index c3f38a104..94eb43276 100644 --- a/examples/usage/openai_chat_speculative.py +++ b/examples/usage/openai_chat_speculative.py @@ -19,7 +19,7 @@ import sglang as sgl from sglang import function, set_default_backend, OpenAI -@function(api_num_spec_tokens=256) +@function(num_api_spec_tokens=256) def gen_character_spec(s): s += sgl.system("You are a helpful assistant.") s += sgl.user("Construct a character within the following format:") @@ -28,7 +28,7 @@ def gen_character_spec(s): s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n")) -@function(api_num_spec_tokens=256) +@function(num_api_spec_tokens=256) def gen_character_spec_no_few_shot(s): s += sgl.user("Construct a character. For each field stop with a newline\n") s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nAge:" + sgl.gen("age", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n")) @@ -41,7 +41,7 @@ def gen_character_normal(s): s += sgl.assistant(sgl.gen("answer", max_tokens=64)) -@function(api_num_spec_tokens=1024) +@function(num_api_spec_tokens=1024) def multi_turn_question(s, question_1, question_2): s += sgl.system("You are a helpful assistant.") s += sgl.user("Answer questions in the following format:") diff --git a/examples/usage/openai_speculative.py b/examples/usage/openai_speculative.py index 24eeac815..c64694da6 100644 --- a/examples/usage/openai_speculative.py +++ b/examples/usage/openai_speculative.py @@ -5,7 +5,7 @@ python3 openai_speculative.py from sglang import function, gen, set_default_backend, OpenAI -@function(api_num_spec_tokens=64) +@function(num_api_spec_tokens=64) def gen_character_spec(s): s += "Construct a character within the following format:\n" s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n" @@ -23,7 +23,7 @@ def gen_character_no_spec(s): s += "\nJob:" + gen("job", stop="\n") + "\n" -@function(api_num_spec_tokens=64) +@function(num_api_spec_tokens=64) def gen_character_spec_no_few_shot(s): # s += "Construct a character with name, birthday, and job:\n" s += "Construct a character:\n" diff --git a/python/sglang/api.py b/python/sglang/api.py index 4448333a6..2a935a4e0 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -20,13 +20,13 @@ from sglang.lang.ir import ( def function( - func: Optional[Callable] = None, api_num_spec_tokens: Optional[int] = None + func: Optional[Callable] = None, num_api_spec_tokens: Optional[int] = None ): if func: - return SglFunction(func, api_num_spec_tokens=api_num_spec_tokens) + return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens) def decorator(func): - return SglFunction(func, api_num_spec_tokens=api_num_spec_tokens) + return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens) return decorator diff --git a/python/sglang/backend/openai.py b/python/sglang/backend/openai.py index 3fd0fe80f..2cb5992d8 100644 --- a/python/sglang/backend/openai.py +++ b/python/sglang/backend/openai.py @@ -106,12 +106,12 @@ class OpenAI(BaseBackend): return self.chat_template def _prepare_spec_execution(self, sampling_params: SglSamplingParams, - api_num_spec_tokens: int, spec_var_name: str): + num_api_spec_tokens: int, spec_var_name: str): if "max_tokens" not in self.spec_kwargs: - self.spec_kwargs["max_tokens"] = api_num_spec_tokens + self.spec_kwargs["max_tokens"] = num_api_spec_tokens else: assert ( - self.spec_kwargs["max_tokens"] == api_num_spec_tokens + self.spec_kwargs["max_tokens"] == num_api_spec_tokens ) params = sampling_params.to_openai_kwargs() @@ -142,17 +142,17 @@ class OpenAI(BaseBackend): ): if sampling_params.dtype is None: if self.is_chat_model: - if s.api_num_spec_tokens is None: + if s.num_api_spec_tokens is None: if not s.text_.endswith(self.chat_prefix): raise RuntimeError( "This use case is not supported if api speculative execution is off. " "For OpenAI chat models, sgl.gen must be right after sgl.assistant. " - "Example of adding api speculative execution: @function(api_num_spec_tokens=128)." + "Example of adding api speculative execution: @function(num_api_spec_tokens=128)." ) prompt = s.messages_ else: return self._prepare_spec_execution(sampling_params, - s.api_num_spec_tokens, spec_var_name) + s.num_api_spec_tokens, spec_var_name) else: prompt = s.text_ @@ -230,7 +230,7 @@ class OpenAI(BaseBackend): self, s: StreamExecutor, ): - if s.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix): + if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix): return comp = "" diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 279dfea26..789879f00 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -66,7 +66,7 @@ def run_program( default_sampling_para, chat_template=None, stream=stream, - api_num_spec_tokens=program.api_num_spec_tokens, + num_api_spec_tokens=program.num_api_spec_tokens, ) state = ProgramState(stream_executor) @@ -178,7 +178,7 @@ class StreamExecutor: default_sampling_para, chat_template, stream, - api_num_spec_tokens=None, + num_api_spec_tokens=None, use_thread=True, ): self.sid = uuid.uuid4().hex @@ -210,7 +210,7 @@ class StreamExecutor: self.fork_start_text_pos = None # For speculative execution - self.api_num_spec_tokens = api_num_spec_tokens + self.num_api_spec_tokens = num_api_spec_tokens self.speculated_text = "" # Worker thread @@ -399,7 +399,7 @@ class StreamExecutor: if ( self.cur_role == "assistant" - and self.api_num_spec_tokens is not None + and self.num_api_spec_tokens is not None and self.backend.is_chat_model and not prefix ): @@ -444,7 +444,7 @@ class StreamExecutor: nonlocal meta_info sampling_params.max_new_tokens = max( - sampling_params.max_new_tokens, self.api_num_spec_tokens + sampling_params.max_new_tokens, self.num_api_spec_tokens ) sampling_params.stop = None self.speculated_text, meta_info = self.backend.generate( @@ -490,7 +490,7 @@ class StreamExecutor: name = expr.name if not self.stream: - if self.api_num_spec_tokens is None: + if self.num_api_spec_tokens is None: comp, meta_info = self.backend.generate( self, sampling_params=sampling_params, @@ -517,7 +517,7 @@ class StreamExecutor: self.variable_event[name].set() else: assert ( - self.api_num_spec_tokens is None + self.num_api_spec_tokens is None ), "stream is not supported with api speculative execution" generator = self.backend.generate_stream( self, sampling_params=sampling_params @@ -580,7 +580,7 @@ class StreamExecutor: def _execute_role_end(self, expr: SglRoleEnd): if ( self.cur_role == "assistant" - and self.api_num_spec_tokens is not None + and self.num_api_spec_tokens is not None and self.backend.is_chat_model ): # Execute the stored lazy generation calls diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index 3506e6ba3..2265a0a7a 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -97,9 +97,9 @@ class SglSamplingParams: class SglFunction: - def __init__(self, func, api_num_spec_tokens=None, bind_arguments=None): + def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None): self.func = func - self.api_num_spec_tokens = api_num_spec_tokens + self.num_api_spec_tokens = num_api_spec_tokens self.bind_arguments = bind_arguments or {} self.pin_prefix_rid = None diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py index b887c09b2..2be7ecdb9 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -353,7 +353,7 @@ def test_regex(): def test_completion_speculative(): - @sgl.function(api_num_spec_tokens=64) + @sgl.function(num_api_spec_tokens=64) def gen_character_spec(s): s += "Construct a character within the following format:\n" s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n" @@ -384,7 +384,7 @@ def test_completion_speculative(): def test_chat_completion_speculative(): - @sgl.function(api_num_spec_tokens=256) + @sgl.function(num_api_spec_tokens=256) def gen_character_spec(s): s += sgl.system("You are a helpful assistant.") s += sgl.user("Construct a character within the following format:")