diff --git a/examples/usage/openaichat_speculative.py b/examples/usage/openai_chat_speculative.py similarity index 83% rename from examples/usage/openaichat_speculative.py rename to examples/usage/openai_chat_speculative.py index b6d9aff75..c3f38a104 100644 --- a/examples/usage/openaichat_speculative.py +++ b/examples/usage/openai_chat_speculative.py @@ -1,22 +1,25 @@ """ Usage: -***Note: for speculative execution to work, user must put all "gen" in "assistant". Show in "assistant" the desired answer format. Each "gen" term should have a stop token. The stream mode is not supported in speculative execution. +***Note: for speculative execution to work, user must put all "gen" in "assistant". +Show in "assistant" the desired answer format. Each "gen" term should have a stop token. +The stream mode is not supported in speculative execution. + E.g. correct: sgl.assistant("\nName:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n")) -incorrect: +incorrect: s += sgl.assistant("\nName:" + sgl.gen("name", stop="\n")) s += sgl.assistant("\nBirthday:" + sgl.gen("birthday", stop="\n")) s += sgl.assistant("\nJob:" + sgl.gen("job", stop="\n")) export OPENAI_API_KEY=sk-****** -python3 openaichat_speculative.py +python3 openai_chat_speculative.py """ import sglang as sgl -from sglang import function, gen, set_default_backend, OpenAI +from sglang import function, set_default_backend, OpenAI -@function(api_num_spec_tokens=512) +@function(api_num_spec_tokens=256) def gen_character_spec(s): s += sgl.system("You are a helpful assistant.") s += sgl.user("Construct a character within the following format:") @@ -25,7 +28,7 @@ def gen_character_spec(s): s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n")) -@function(api_num_spec_tokens=512) +@function(api_num_spec_tokens=256) def gen_character_spec_no_few_shot(s): s += 
sgl.user("Construct a character. For each field stop with a newline\n") s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nAge:" + sgl.gen("age", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n")) @@ -44,18 +47,21 @@ def multi_turn_question(s, question_1, question_2): s += sgl.user("Answer questions in the following format:") s += sgl.user("Question 1: What is the capital of France?\nQuestion 2: What is the population of this city?\n") s += sgl.assistant("Answer 1: The capital of France is Paris.\nAnswer 2: The population of Paris in 2024 is estimated to be around 2.1 million for the city proper.\n") - s += sgl.user("Question 1: "+question_1+"\nQuestion 2: "+question_2) - s += sgl.assistant("Answer 1: "+sgl.gen("answer_1", stop="\n") + "\nAnswer 2: "+ sgl.gen("answer_2", stop="\n")) + s += sgl.user("Question 1: " + question_1+"\nQuestion 2: " + question_2) + s += sgl.assistant("Answer 1: " + sgl.gen("answer_1", stop="\n") + "\nAnswer 2: " + sgl.gen("answer_2", stop="\n")) def test_spec_single_turn(): + backend.token_usage.reset() + state = gen_character_spec.run() for m in state.messages(): print(m["role"], ":", m["content"]) print("\n-- name:", state["name"]) - print("\n-- birthday:", state["birthday"]) - print("\n-- job:", state["job"]) + print("-- birthday:", state["birthday"]) + print("-- job:", state["job"]) + print(backend.token_usage) def test_inaccurate_spec_single_turn(): @@ -99,7 +105,8 @@ def test_spec_multi_turn_stream(): if __name__ == "__main__": - set_default_backend(OpenAI("gpt-4-turbo")) + backend = OpenAI("gpt-4-turbo") + set_default_backend(backend) print("\n========== test spec single turn ==========\n") # expect reasonable answer for each field @@ -119,5 +126,4 @@ if __name__ == "__main__": print("\n========== test spec multi turn stream ==========\n") # expect error in stream_executor: stream is not supported... 
- test_spec_multi_turn_stream() - + test_spec_multi_turn_stream() \ No newline at end of file diff --git a/examples/usage/openai_speculative.py b/examples/usage/openai_speculative.py index 4b1d04f2a..24eeac815 100644 --- a/examples/usage/openai_speculative.py +++ b/examples/usage/openai_speculative.py @@ -5,7 +5,7 @@ python3 openai_speculative.py from sglang import function, gen, set_default_backend, OpenAI -@function(api_num_spec_tokens=512) +@function(api_num_spec_tokens=64) def gen_character_spec(s): s += "Construct a character within the following format:\n" s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n" @@ -14,6 +14,15 @@ def gen_character_spec(s): s += "\nJob:" + gen("job", stop="\n") + "\n" +@function +def gen_character_no_spec(s): + s += "Construct a character within the following format:\n" + s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n" + s += "\nPlease generate new Name, Birthday and Job.\n" + s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n") + s += "\nJob:" + gen("job", stop="\n") + "\n" + + @function(api_num_spec_tokens=64) def gen_character_spec_no_few_shot(s): # s += "Construct a character with name, birthday, and job:\n" @@ -22,17 +31,19 @@ def gen_character_spec_no_few_shot(s): s += "\nJob:" + gen("job", stop="\n") + "\n" -set_default_backend(OpenAI("gpt-3.5-turbo-instruct")) +if __name__ == "__main__": + backend = OpenAI("gpt-3.5-turbo-instruct") + set_default_backend(backend) -state = gen_character_spec.run() + for function in [gen_character_spec, gen_character_no_spec, gen_character_spec_no_few_shot]: + backend.token_usage.reset() -print("...name:", state["name"]) -print("...birthday:", state["birthday"]) -print("...job:", state["job"]) + print(f"function: {function.func.__name__}") -state = gen_character_spec_no_few_shot.run() - -print("\n...name:", state["name"]) -print("...birthday:", state["birthday"]) -print("...job:", state["job"]) + state = 
function.run() + print("...name:", state["name"]) + print("...birthday:", state["birthday"]) + print("...job:", state["job"]) + print(backend.token_usage) + print() \ No newline at end of file diff --git a/python/sglang/backend/base_backend.py b/python/sglang/backend/base_backend.py index 716981365..cb504f51b 100644 --- a/python/sglang/backend/base_backend.py +++ b/python/sglang/backend/base_backend.py @@ -9,7 +9,6 @@ class BaseBackend: def __init__(self) -> None: self.support_concate_and_append = False self.chat_template = get_chat_template("default") - self.api_num_spec_tokens = None def get_model_name(self): raise NotImplementedError() diff --git a/python/sglang/backend/openai.py b/python/sglang/backend/openai.py index df3333523..3fd0fe80f 100644 --- a/python/sglang/backend/openai.py +++ b/python/sglang/backend/openai.py @@ -1,6 +1,7 @@ import logging import time import warnings +import dataclasses from typing import Callable, List, Optional, Union import numpy as np @@ -42,6 +43,15 @@ INSTRUCT_MODEL_NAMES = [ ] +@dataclasses.dataclass +class TokenUsage: + prompt_tokens: int + completion_tokens: int + + def reset(self): + self.prompt_tokens = self.completion_tokens = 0 + + class OpenAI(BaseBackend): def __init__( self, @@ -83,66 +93,73 @@ class OpenAI(BaseBackend): self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0] + # Usage + self.token_usage = TokenUsage(0, 0) + + # API speculative execution + # TODO(ying): This does not support multi-threading (run_batch) self.spec_kwargs = {} self.spec_format = [] self.spec_max_num_tries = 3 - self.api_num_spec_tokens = None - - def set_api_num_spec_tokens(self, num): - self.api_num_spec_tokens = num def get_chat_template(self): return self.chat_template + def _prepare_spec_execution(self, sampling_params: SglSamplingParams, + api_num_spec_tokens: int, spec_var_name: str): + if "max_tokens" not in self.spec_kwargs: + self.spec_kwargs["max_tokens"] = api_num_spec_tokens + else: + assert ( + 
self.spec_kwargs["max_tokens"] == api_num_spec_tokens + ) + + params = sampling_params.to_openai_kwargs() + for key, value in params.items(): + if key in ["stop"]: + continue + if key in ["max_tokens"]: + warnings.warn( + "The parameter max_tokens will be overwritten by speculated number of tokens." + ) + continue + if key not in self.spec_kwargs: + self.spec_kwargs[key] = value + else: + assert ( + value == self.spec_kwargs[key] + ), "sampling parameters should be consistent if turn on api speculative execution." + self.spec_format.append( + {"text": "", "stop": params["stop"], "name": spec_var_name} + ) + return "", {} + def generate( self, s: StreamExecutor, sampling_params: SglSamplingParams, - name=None, + spec_var_name: str = None, ): if sampling_params.dtype is None: if self.is_chat_model: - if self.api_num_spec_tokens is None: + if s.api_num_spec_tokens is None: if not s.text_.endswith(self.chat_prefix): raise RuntimeError( "This use case is not supported if api speculative execution is off. " - "For OpenAI chat models, sgl.gen must be right after sgl.assistant." + "For OpenAI chat models, sgl.gen must be right after sgl.assistant. " "Example of adding api speculative execution: @function(api_num_spec_tokens=128)." ) prompt = s.messages_ else: - # collect assistant answer format - if "max_tokens" not in self.spec_kwargs: - self.spec_kwargs["max_tokens"] = self.api_num_spec_tokens - else: - assert ( - self.spec_kwargs["max_tokens"] == self.api_num_spec_tokens - ) - params = sampling_params.to_openai_kwargs() - for key, value in params.items(): - if key in ["stop"]: - continue - if key in ["max_tokens"]: - warnings.warn( - "The parameter max_tokens will be overwritten by speculated number of tokens." - ) - continue - if key not in self.spec_kwargs: - self.spec_kwargs[key] = value - else: - assert ( - value == self.spec_kwargs[key] - ), "sampling parameters should be consistent if turn on api speculative execution." 
- self.spec_format.append( - {"text": "", "stop": params["stop"], "name": name} - ) - return "", {} + return self._prepare_spec_execution(sampling_params, + s.api_num_spec_tokens, spec_var_name) else: prompt = s.text_ kwargs = sampling_params.to_openai_kwargs() comp = openai_completion( client=self.client, + token_usage=self.token_usage, is_chat=self.is_chat_model, model=self.model_name, prompt=prompt, @@ -156,6 +173,7 @@ class OpenAI(BaseBackend): kwargs.pop("stop") comp = openai_completion( client=self.client, + token_usage=self.token_usage, is_chat=self.is_chat_model, model=self.model_name, prompt=s.text_ + '"', @@ -171,6 +189,7 @@ class OpenAI(BaseBackend): kwargs.pop("stop") comp = openai_completion( client=self.client, + token_usage=self.token_usage, is_chat=self.is_chat_model, model=self.model_name, prompt=s.text_, @@ -211,14 +230,16 @@ class OpenAI(BaseBackend): self, s: StreamExecutor, ): - if self.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix): + if s.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix): return comp = "" if not all(x["name"] is None for x in self.spec_format): + # TODO(ying): throw errors or warnings for i in range(self.spec_max_num_tries): comp = openai_completion( client=self.client, + token_usage=self.token_usage, is_chat=self.is_chat_model, model=self.model_name, prompt=s.messages_, @@ -228,7 +249,6 @@ class OpenAI(BaseBackend): break for term in self.spec_format: - stop = term["stop"] if term["stop"] is not None else "" s.text_ += term["text"] name = term["name"] if name is not None: @@ -258,6 +278,7 @@ class OpenAI(BaseBackend): kwargs = sampling_params.to_openai_kwargs() generator = openai_completion_stream( client=self.client, + token_usage=self.token_usage, is_chat=self.is_chat_model, model=self.model_name, prompt=prompt, @@ -303,6 +324,8 @@ class OpenAI(BaseBackend): ) ret_str = ret.choices[0].text ret_token = self.tokenizer.encode(ret_str)[0] + self.token_usage.prompt_tokens += 
ret.usage.prompt_tokens +         self.token_usage.completion_tokens += ret.usage.completion_tokens          # TODO:          # 1. return logits as the scores @@ -332,7 +355,7 @@ class OpenAI(BaseBackend):          return decision, scores, None, None  -def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs): +def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):      for attempt in range(retries):          try:              if is_chat: @@ -346,6 +369,9 @@ def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):                  comp = [c.text for c in ret.choices]              else:                  comp = ret.choices[0].text + +            token_usage.prompt_tokens += ret.usage.prompt_tokens +            token_usage.completion_tokens += ret.usage.completion_tokens              break          except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:              logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...") @@ -359,16 +385,19 @@ def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):      return comp  -def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwargs): +def openai_completion_stream(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):      for attempt in range(retries):          try:              if is_chat:                  if "stop" in kwargs and kwargs["stop"] is None:                      kwargs.pop("stop")                  generator = client.chat.completions.create( -                    messages=prompt, stream=True, **kwargs +                    messages=prompt, stream=True, stream_options={"include_usage": True}, +                    **kwargs                  )                  for ret in generator: +                    if len(ret.choices) == 0: +                        continue                      try:                          content = ret.choices[0].delta.content                      except IndexError: @@ -376,11 +405,17 @@ def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwa                  yield content or "", {}              else:                  generator = client.completions.create( -                    prompt=prompt, stream=True, **kwargs +                    prompt=prompt, stream=True, stream_options={"include_usage": True}, +                    **kwargs                  )                  for ret in generator: +                    if len(ret.choices) == 0: +                        continue                      content = ret.choices[0].text                      yield content or "", {} + + 
token_usage.prompt_tokens += ret.usage.prompt_tokens + token_usage.completion_tokens += ret.usage.completion_tokens break except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e: logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...") diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index fc7faf846..449417337 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -196,12 +196,6 @@ class StreamExecutor: # For completion self.text_ = "" # The full text - # For speculative execution - from sglang.backend.openai import OpenAI - if isinstance(backend, OpenAI): - self.backend.set_api_num_spec_tokens(api_num_spec_tokens) - self.speculated_text = "" - # For chat self.messages_ = [] # The messages in the OpenAI API format self.chat_template = chat_template or self.backend.get_chat_template() @@ -215,6 +209,10 @@ class StreamExecutor: # For fork/join self.fork_start_text_pos = None + # For speculative execution + self.api_num_spec_tokens = api_num_spec_tokens + self.speculated_text = "" + # Worker thread self.use_thread = use_thread if self.use_thread: @@ -293,6 +291,8 @@ class StreamExecutor: exes[i].fork_start_text_pos = len(self.text_) exes[i].images_ = list(self.images_) + # TODO(ying): handle API speculative execution + return exes def text(self): @@ -399,7 +399,7 @@ class StreamExecutor: if ( self.cur_role == "assistant" - and self.backend.api_num_spec_tokens is not None + and self.api_num_spec_tokens is not None and self.backend.is_chat_model and not prefix ): @@ -435,71 +435,80 @@ class StreamExecutor: # if global_config.eager_fill_image: # self.backend.fill_image(self) + def _spec_gen(self, sampling_params): + stop = sampling_params.stop + max_new_tokens = sampling_params.max_new_tokens + meta_info = {} + + def regen(): + nonlocal meta_info + + sampling_params.max_new_tokens = max( + sampling_params.max_new_tokens, self.api_num_spec_tokens + ) + sampling_params.stop = None 
+ self.speculated_text, meta_info = self.backend.generate( + self, sampling_params=sampling_params + ) + + def find_stop(): + if isinstance(stop, str): + return self.speculated_text.find(stop) + elif isinstance(stop, (tuple, list)): + pos = -1 + for stop_str in stop: + stop_pos = self.speculated_text.find(stop_str) + if stop_pos != -1 and (pos == -1 or stop_pos < pos): + pos = stop_pos + return pos + else: + raise Exception("Wrong type of stop in sampling parameters.") + + if stop is None: + if len(self.speculated_text) < max_new_tokens: + regen() + comp = self.speculated_text[:max_new_tokens] + self.speculated_text = self.speculated_text[max_new_tokens:] + elif isinstance(stop, (str, list, tuple)): + if self.speculated_text == "": + regen() + stop_pos = find_stop() + if stop_pos == -1: + stop_pos = min( + sampling_params.max_new_tokens, + len(self.speculated_text), + ) + comp = self.speculated_text[:stop_pos] + self.speculated_text = self.speculated_text[stop_pos:] + else: + raise ValueError("Wrong type of stop in sampling parameters.") + + return comp, meta_info + def _execute_gen(self, expr: SglGen): sampling_params = self._resolve_sampling_params(expr.sampling_params) name = expr.name if not self.stream: - if self.backend.api_num_spec_tokens is None: + if self.api_num_spec_tokens is None: comp, meta_info = self.backend.generate( self, sampling_params=sampling_params, ) - - elif self.backend.is_chat_model: - # spec on model with only chat interface - comp, meta_info = self.backend.generate( - self, - sampling_params=sampling_params, - name=name, - ) - return - - else: # spec on model with completion - stop = sampling_params.stop - max_new_tokens = sampling_params.max_new_tokens - meta_info = {} - - def regen(): - sampling_params.max_new_tokens = max( - sampling_params.max_new_tokens, self.backend.api_num_spec_tokens - ) - sampling_params.stop = None - self.speculated_text, meta_info = self.backend.generate( - self, sampling_params=sampling_params + else: + if 
self.backend.is_chat_model: + # Speculative execution on models with only chat interface. + # Store the calls into a temporary list. + # They will be lazily executed later. + comp, meta_info = self.backend.generate( + self, + sampling_params=sampling_params, + spec_var_name=name, ) + return - def find_stop(): - if isinstance(stop, str): - return self.speculated_text.find(stop) - elif isinstance(stop, (tuple, list)): - pos = -1 - for stop_str in stop: - stop_pos = self.speculated_text.find(stop_str) - if stop_pos != -1 and (pos == -1 or stop_pos < pos): - pos = stop_pos - return pos - else: - raise Exception("Wrong type of stop in sampling parameters.") - - if stop is None: - if len(self.speculated_text) < max_new_tokens: - regen() - comp = self.speculated_text[:max_new_tokens] - self.speculated_text = self.speculated_text[max_new_tokens:] - elif isinstance(stop, (str, list, tuple)): - if self.speculated_text == "": - regen() - stop_pos = find_stop() - if stop_pos == -1: - stop_pos = min( - sampling_params.max_new_tokens, - len(self.speculated_text), - ) - comp = self.speculated_text[:stop_pos] - self.speculated_text = self.speculated_text[stop_pos:] - else: - raise ValueError("Wrong type of stop in sampling parameters.") + else: # Speculative execution on models with completion interface + comp, meta_info = self._spec_gen(sampling_params) self.text_ += comp @@ -508,7 +517,7 @@ class StreamExecutor: self.variable_event[name].set() else: assert ( - self.backend.api_num_spec_tokens is None + self.api_num_spec_tokens is None ), "stream is not supported with api speculative execution" generator = self.backend.generate_stream( self, sampling_params=sampling_params @@ -571,9 +580,10 @@ class StreamExecutor: def _execute_role_end(self, expr: SglRoleEnd): if ( self.cur_role == "assistant" - and self.backend.api_num_spec_tokens is not None and self.backend.is_chat_model + and self.api_num_spec_tokens is not None ): + # Execute the stored lazy generation calls 
self.backend.role_end_generate(self) self.cur_role = None diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py index 853ed9846..b887c09b2 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -304,6 +304,7 @@ def test_image_qa(): temperature=0, max_new_tokens=64, ) + assert ( "taxi" in state.messages()[-1]["content"] or "car" in state.messages()[-1]["content"] @@ -349,3 +350,46 @@ def test_regex(): state = regex_gen.run() answer = state["answer"] assert re.match(regex, answer) + + +def test_completion_speculative(): + @sgl.function(api_num_spec_tokens=64) + def gen_character_spec(s): + s += "Construct a character within the following format:\n" + s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n" + s += "\nPlease generate new Name, Birthday and Job.\n" + s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + s += "\nJob:" + sgl.gen("job", stop="\n") + "\n" + + + @sgl.function + def gen_character_no_spec(s): + s += "Construct a character within the following format:\n" + s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n" + s += "\nPlease generate new Name, Birthday and Job.\n" + s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + s += "\nJob:" + sgl.gen("job", stop="\n") + "\n" + + token_usage = sgl.global_config.default_backend.token_usage + + token_usage.reset() + gen_character_spec().sync() + usage_with_spec = token_usage.prompt_tokens + + token_usage.reset() + gen_character_no_spec().sync() + usage_with_no_spec = token_usage.prompt_tokens + + assert usage_with_spec < usage_with_no_spec, f"{usage_with_spec} vs {usage_with_no_spec}" + + +def test_chat_completion_speculative(): + @sgl.function(api_num_spec_tokens=256) + def gen_character_spec(s): + s += sgl.system("You are a helpful assistant.") + s += sgl.user("Construct a character within the following format:") + s 
+= sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n") + s += sgl.user("Please generate new Name, Birthday and Job.\n") + s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n")) + + gen_character_spec().sync() \ No newline at end of file diff --git a/test/lang/test_openai_backend.py b/test/lang/test_openai_backend.py index 90d097956..b028bc0ab 100644 --- a/test/lang/test_openai_backend.py +++ b/test/lang/test_openai_backend.py @@ -14,6 +14,8 @@ from sglang.test.test_programs import ( test_select, test_stream, test_tool_use, + test_completion_speculative, + test_chat_completion_speculative ) @@ -78,6 +80,14 @@ class TestOpenAIBackend(unittest.TestCase): set_default_backend(self.backend) test_stream() + def test_completion_speculative(self): + set_default_backend(self.backend) + test_completion_speculative() + + def test_chat_completion_speculative(self): + set_default_backend(self.chat_backend) + test_chat_completion_speculative() + if __name__ == "__main__": unittest.main(warnings="ignore") @@ -87,4 +97,4 @@ if __name__ == "__main__": # global_config.verbosity = 2 # t = TestOpenAIBackend() # t.setUp() - # t.test_few_shot_qa() + # t.test_chat_completion_speculative() \ No newline at end of file