diff --git a/examples/usage/openai_speculative.py b/examples/usage/openai_speculative.py
index cb06428da..4b1d04f2a 100644
--- a/examples/usage/openai_speculative.py
+++ b/examples/usage/openai_speculative.py
@@ -14,10 +14,25 @@ def gen_character_spec(s):
     s += "\nJob:" + gen("job", stop="\n") + "\n"
 
 
+@function(api_num_spec_tokens=64)
+def gen_character_spec_no_few_shot(s):
+    # s += "Construct a character with name, birthday, and job:\n"
+    s += "Construct a character:\n"
+    s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
+    s += "\nJob:" + gen("job", stop="\n") + "\n"
+
+
 set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))
 
 state = gen_character_spec.run()
 
-print("name:", state["name"])
-print("birthday:", state["birthday"])
-print("job:", state["job"])
+print("...name:", state["name"])
+print("...birthday:", state["birthday"])
+print("...job:", state["job"])
+
+state = gen_character_spec_no_few_shot.run()
+
+print("\n...name:", state["name"])
+print("...birthday:", state["birthday"])
+print("...job:", state["job"])
+
diff --git a/examples/usage/openaichat_speculative.py b/examples/usage/openaichat_speculative.py
new file mode 100644
index 000000000..b6d9aff75
--- /dev/null
+++ b/examples/usage/openaichat_speculative.py
@@ -0,0 +1,123 @@
+"""
+Usage:
+***Note: for speculative execution to work, the user must put all "gen" calls inside a single "assistant" block and show the desired answer format there. Each "gen" term should have a stop token. Stream mode is not supported with speculative execution.
+E.g.
+correct:
+    sgl.assistant("\nName:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
+incorrect:
+    s += sgl.assistant("\nName:" + sgl.gen("name", stop="\n"))
+    s += sgl.assistant("\nBirthday:" + sgl.gen("birthday", stop="\n"))
+    s += sgl.assistant("\nJob:" + sgl.gen("job", stop="\n"))
+
+export OPENAI_API_KEY=sk-******
+python3 openaichat_speculative.py
+"""
+import sglang as sgl
+from sglang import function, gen, set_default_backend, OpenAI
+
+
+@function(api_num_spec_tokens=512)
+def gen_character_spec(s):
+    s += sgl.system("You are a helpful assistant.")
+    s += sgl.user("Construct a character within the following format:")
+    s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n")
+    s += sgl.user("Please generate new Name, Birthday and Job.\n")
+    s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
+
+
+@function(api_num_spec_tokens=512)
+def gen_character_spec_no_few_shot(s):
+    s += sgl.user("Construct a character. For each field stop with a newline\n")
+    s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nAge:" + sgl.gen("age", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
+
+
+@function
+def gen_character_normal(s):
+    s += sgl.system("You are a helpful assistant.")
+    s += sgl.user("What is the answer to 23 + 8?")
+    s += sgl.assistant(sgl.gen("answer", max_tokens=64))
+
+
+@function(api_num_spec_tokens=1024)
+def multi_turn_question(s, question_1, question_2):
+    s += sgl.system("You are a helpful assistant.")
+    s += sgl.user("Answer questions in the following format:")
+    s += sgl.user("Question 1: What is the capital of France?\nQuestion 2: What is the population of this city?\n")
+    s += sgl.assistant("Answer 1: The capital of France is Paris.\nAnswer 2: The population of Paris in 2024 is estimated to be around 2.1 million for the city proper.\n")
+    s += sgl.user("Question 1: " + question_1 + "\nQuestion 2: " + question_2)
+    s += sgl.assistant("Answer 1: " + sgl.gen("answer_1", stop="\n") + "\nAnswer 2: " + sgl.gen("answer_2", stop="\n"))
+
+
+def test_spec_single_turn():
+    state = gen_character_spec.run()
+    for m in state.messages():
+        print(m["role"], ":", m["content"])
+
+    print("\n-- name:", state["name"])
+    print("\n-- birthday:", state["birthday"])
+    print("\n-- job:", state["job"])
+
+
+def test_inaccurate_spec_single_turn():
+    state = gen_character_spec_no_few_shot.run()
+    for m in state.messages():
+        print(m["role"], ":", m["content"])
+
+    print("\n-- name:", state["name"])
+    print("\n-- age:", state["age"])
+    print("\n-- job:", state["job"])
+
+
+def test_normal_single_turn():
+    state = gen_character_normal.run()
+    for m in state.messages():
+        print(m["role"], ":", m["content"])
+
+
+def test_spec_multi_turn():
+    state = multi_turn_question.run(
+        question_1="What is the capital of the United States?",
+        question_2="List two local attractions in the capital of the United States.",
+    )
+
+    for m in state.messages():
+        print(m["role"], ":", m["content"])
+
+    print("\n-- answer_1 --\n", state["answer_1"])
+    print("\n-- answer_2 --\n", state["answer_2"])
+
+
+def test_spec_multi_turn_stream():
+    state = multi_turn_question.run(
+        question_1="What is the capital of the United States?",
+        question_2="List two local attractions.",
+        stream=True
+    )
+
+    for out in state.text_iter():
+        print(out, end="", flush=True)
+
+
+if __name__ == "__main__":
+    set_default_backend(OpenAI("gpt-4-turbo"))
+
+    print("\n========== test spec single turn ==========\n")
+    # expect a reasonable answer for each field
+    test_spec_single_turn()
+
+    print("\n========== test inaccurate spec single turn ==========\n")
+    # expect incomplete or unreasonable answers
+    test_inaccurate_spec_single_turn()
+
+    print("\n========== test normal single turn ==========\n")
+    # expect a reasonable answer
+    test_normal_single_turn()
+
+    print("\n========== test spec multi turn ==========\n")
+    # expect answers in the same format as the few-shot example
+    test_spec_multi_turn()
+
+    print("\n========== test spec multi turn stream ==========\n")
+    # expect an error in stream_executor: stream is not supported...
+    test_spec_multi_turn_stream()
+
diff --git a/python/sglang/backend/base_backend.py b/python/sglang/backend/base_backend.py
index cb504f51b..716981365 100644
--- a/python/sglang/backend/base_backend.py
+++ b/python/sglang/backend/base_backend.py
@@ -9,6 +9,7 @@ class BaseBackend:
     def __init__(self) -> None:
         self.support_concate_and_append = False
         self.chat_template = get_chat_template("default")
+        self.api_num_spec_tokens = None
 
     def get_model_name(self):
         raise NotImplementedError()
diff --git a/python/sglang/backend/openai.py b/python/sglang/backend/openai.py
index 5ac1d9447..df3333523 100644
--- a/python/sglang/backend/openai.py
+++ b/python/sglang/backend/openai.py
@@ -1,5 +1,6 @@
 import logging
 import time
+import warnings
 from typing import Callable, List, Optional, Union
 
 import numpy as np
@@ -80,7 +81,15 @@ class OpenAI(BaseBackend):
         else:
             self.is_chat_model = True
 
-        self.chat_begin_str = self.chat_template.role_prefix_and_suffix["assistant"][0]
+        self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]
+
+        self.spec_kwargs = {}
+        self.spec_format = []
+        self.spec_max_num_tries = 3
+        self.api_num_spec_tokens = None
+
+    def set_api_num_spec_tokens(self, num):
+        self.api_num_spec_tokens = num
 
     def get_chat_template(self):
         return self.chat_template
@@ -89,15 +98,45 @@
         self,
         s: StreamExecutor,
         sampling_params: SglSamplingParams,
+        name=None,
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-                if not s.text_.endswith(self.chat_begin_str):
-                    raise RuntimeError(
-                        "This use case is not supported. "
-                        "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
+                if self.api_num_spec_tokens is None:
+                    if not s.text_.endswith(self.chat_prefix):
+                        raise RuntimeError(
+                            "This use case is not supported when API speculative execution is off. "
+                            "For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
+                            "Example of enabling API speculative execution: @function(api_num_spec_tokens=128)."
+                        )
+                    prompt = s.messages_
+                else:
+                    # collect the assistant answer format
+                    if "max_tokens" not in self.spec_kwargs:
+                        self.spec_kwargs["max_tokens"] = self.api_num_spec_tokens
+                    else:
+                        assert (
+                            self.spec_kwargs["max_tokens"] == self.api_num_spec_tokens
+                        )
+                    params = sampling_params.to_openai_kwargs()
+                    for key, value in params.items():
+                        if key in ["stop"]:
+                            continue
+                        if key in ["max_tokens"]:
+                            warnings.warn(
+                                "The parameter max_tokens will be overwritten by the speculated number of tokens."
+                            )
+                            continue
+                        if key not in self.spec_kwargs:
+                            self.spec_kwargs[key] = value
+                        else:
+                            assert (
+                                value == self.spec_kwargs[key]
+                            ), "Sampling parameters should be consistent when API speculative execution is on."
+                    self.spec_format.append(
+                        {"text": "", "stop": params["stop"], "name": name}
                     )
-                prompt = s.messages_
+                    return "", {}
             else:
                 prompt = s.text_
@@ -110,6 +149,9 @@
                 **kwargs,
             )
         elif sampling_params.dtype in [str, "str", "string"]:
+            assert (
+                not self.is_chat_model
+            ), "Constrained types are not supported on chat models."
             kwargs = sampling_params.to_openai_kwargs()
             kwargs.pop("stop")
             comp = openai_completion(
@@ -122,6 +164,9 @@
             )
             comp = '"' + comp + '"'
         elif sampling_params.dtype in [int, "int"]:
+            assert (
+                not self.is_chat_model
+            ), "Constrained types are not supported on chat models."
             kwargs = sampling_params.to_openai_kwargs()
             kwargs.pop("stop")
             comp = openai_completion(
@@ -138,6 +183,62 @@
 
         return comp, {}
 
+    def spec_fill(self, value: str):
+        assert self.is_chat_model
+        self.spec_format.append({"text": value, "stop": None, "name": None})
+
+    def spec_pattern_match(self, comp):
+        for i, term in enumerate(self.spec_format):
+            text = term["text"]
+            if text != "":
+                if comp.startswith(text):
+                    comp = comp[len(text) :]
+                else:
+                    return False
+            else:
+                pos = comp.find(term["stop"])
+                if pos != -1:
+                    term["text"] = comp[:pos]
+                    comp = comp[pos:]
+                else:
+                    if i == len(self.spec_format) - 1:
+                        term["text"] = comp
+                    else:
+                        return False
+        return True
+
+    def role_end_generate(
+        self,
+        s: StreamExecutor,
+    ):
+        if self.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
+            return
+
+        comp = ""
+        if not all(x["name"] is None for x in self.spec_format):
+            for i in range(self.spec_max_num_tries):
+                comp = openai_completion(
+                    client=self.client,
+                    is_chat=self.is_chat_model,
+                    model=self.model_name,
+                    prompt=s.messages_,
+                    **self.spec_kwargs,
+                )
+                if self.spec_pattern_match(comp):
+                    break
+
+        for term in self.spec_format:
+            stop = term["stop"] if term["stop"] is not None else ""
+            s.text_ += term["text"]
+            name = term["name"]
+            if name is not None:
+                s.variables[name] = term["text"]
+                s.meta_info[name] = {}
+                s.variable_event[name].set()
+
+        self.spec_kwargs = {}
+        self.spec_format = []
+
     def generate_stream(
         self,
         s: StreamExecutor,
@@ -145,7 +246,7 @@
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-                if not s.text_.endswith(self.chat_begin_str):
+                if not s.text_.endswith(self.chat_prefix):
                     raise RuntimeError(
                         "This use case is not supported. "
" "For OpenAI chat models, sgl.gen must be right after sgl.assistant" diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/backend/runtime_endpoint.py index 42b4312f6..e0211c3b5 100644 --- a/python/sglang/backend/runtime_endpoint.py +++ b/python/sglang/backend/runtime_endpoint.py @@ -266,4 +266,4 @@ class RuntimeEndpoint(BaseBackend): def _assert_success(self, res): if res.status_code != 200: - raise RuntimeError(res.json()) \ No newline at end of file + raise RuntimeError(res.json()) diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 7c638cba0..fc7faf846 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -6,6 +6,7 @@ import multiprocessing import queue import threading import uuid +import warnings from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from typing import Any, Callable, Dict, List, Optional, Union @@ -185,7 +186,6 @@ class StreamExecutor: self.arguments: Dict[str, Any] = arguments self.default_sampling_para = default_sampling_para self.stream = stream - self.api_num_spec_tokens = api_num_spec_tokens self.variables = {} # Dict[name: str -> value: str] self.variable_event = {} # Dict[name: str -> event: threading.Event] @@ -197,6 +197,9 @@ class StreamExecutor: self.text_ = "" # The full text # For speculative execution + from sglang.backend.openai import OpenAI + if isinstance(backend, OpenAI): + self.backend.set_api_num_spec_tokens(api_num_spec_tokens) self.speculated_text = "" # For chat @@ -322,7 +325,7 @@ class StreamExecutor: try: self._execute(expr) except Exception as e: - # print(f"Error in stream_executor: {get_exception_traceback()}") + warnings.warn(f"Error in stream_executor: {get_exception_traceback()}") error = e break self.queue.task_done() @@ -391,12 +394,23 @@ class StreamExecutor: else: raise ValueError(f"Unknown type: {type(other)}") - def _execute_fill(self, value: str): + def _execute_fill(self, value: str, prefix=False): value = str(value) + + if ( + self.cur_role == "assistant" + and self.backend.api_num_spec_tokens is not None + and self.backend.is_chat_model + and not prefix + ): + self.backend.spec_fill(value) + return + if self.speculated_text.startswith(value): self.speculated_text = self.speculated_text[len(value) :] else: self.speculated_text = "" + self.text_ += value def _execute_image(self, expr: SglImage): @@ -426,14 +440,29 @@ class StreamExecutor: name = expr.name if not self.stream: - if self.api_num_spec_tokens is not None: + if self.backend.api_num_spec_tokens is None: + comp, meta_info = self.backend.generate( + self, + sampling_params=sampling_params, + ) + + elif self.backend.is_chat_model: + # spec on model with only chat interface + comp, meta_info = self.backend.generate( + self, + sampling_params=sampling_params, + name=name, + ) + return + + else: # spec on model with completion stop = sampling_params.stop max_new_tokens = sampling_params.max_new_tokens meta_info = {} def regen(): sampling_params.max_new_tokens = max( - sampling_params.max_new_tokens, self.api_num_spec_tokens + sampling_params.max_new_tokens, self.backend.api_num_spec_tokens ) sampling_params.stop = None self.speculated_text, meta_info = self.backend.generate( @@ -442,16 +471,14 @@ class StreamExecutor: def find_stop(): if isinstance(stop, str): - return self.speculated_text.find(stop), len(stop) + return self.speculated_text.find(stop) elif isinstance(stop, (tuple, list)): pos = -1 - stop_len = 0 for stop_str in stop: stop_pos = 
                            if stop_pos != -1 and (pos == -1 or stop_pos < pos):
                                pos = stop_pos
-                                stop_len = len(stop_str)
-                        return pos, stop_len
+                        return pos
                     else:
                         raise Exception("Wrong type of stop in sampling parameters.")
@@ -463,23 +490,16 @@
                 elif isinstance(stop, (str, list, tuple)):
                     if self.speculated_text == "":
                         regen()
-                    stop_pos, stop_len = find_stop()
+                    stop_pos = find_stop()
                     if stop_pos == -1:
-                        stop_pos, stop_len = (
-                            min(
-                                sampling_params.max_new_tokens,
-                                len(self.speculated_text),
-                            ),
-                            0,
+                        stop_pos = min(
+                            sampling_params.max_new_tokens,
+                            len(self.speculated_text),
                         )
                     comp = self.speculated_text[:stop_pos]
                     self.speculated_text = self.speculated_text[stop_pos:]
                 else:
                     raise ValueError("Wrong type of stop in sampling parameters.")
-            else:
-                comp, meta_info = self.backend.generate(
-                    self, sampling_params=sampling_params
-                )
 
             self.text_ += comp
 
@@ -487,6 +507,9 @@
                 self.meta_info[name] = meta_info
                 self.variable_event[name].set()
         else:
+            assert (
+                self.backend.api_num_spec_tokens is None
+            ), "Stream mode is not supported with API speculative execution."
             generator = self.backend.generate_stream(
                 self, sampling_params=sampling_params
             )
@@ -542,10 +565,18 @@
         prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
 
-        self._execute_fill(prefix)
+        self._execute_fill(prefix, prefix=True)
         self.cur_role_begin_pos = len(self.text_)
 
     def _execute_role_end(self, expr: SglRoleEnd):
+        if (
+            self.cur_role == "assistant"
+            and self.backend.api_num_spec_tokens is not None
+            and self.backend.is_chat_model
+        ):
+            self.backend.role_end_generate(self)
+        self.cur_role = None
+
         new_text = self.text_[self.cur_role_begin_pos :].lstrip()
 
         _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
@@ -572,8 +603,6 @@
             # OpenAI chat API format
             self.messages_.append({"role": expr.role, "content": new_text})
 
-        self.cur_role = None
-
     def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
         self.variables[expr.name] = int(len(self.text_))
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 3439747f5..8da2317c1 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -31,8 +31,9 @@ class GenerateReqInput:
 
     def post_init(self):
-        if ((self.text is None and self.input_ids is None) or
-            (self.text is not None and self.input_ids is not None)):
+        if (self.text is None and self.input_ids is None) or (
+            self.text is not None and self.input_ids is not None
+        ):
             raise ValueError("Either text or input_ids should be provided.")
 
         if self.text is not None:
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index 00564676b..7f08313a5 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -38,7 +38,6 @@ from sglang.srt.utils import (
 )
 from sglang.utils import get_exception_traceback
 
-
 logger = logging.getLogger("model_rpc")
 vllm_default_logger.setLevel(logging.WARN)
 logging.getLogger("vllm.utils").setLevel(logging.WARN)
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 69ed86792..074b11c00 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -341,7 +341,6 @@ class TokenizerManager:
 
         return top_logprobs
 
-
 global global_processor
@@ -385,4 +384,4 @@ def get_pixel_values(
         pixel_values = pixel_values.astype(np.float16)
         return pixel_values, image_hash, image.size
     except Exception:
-        print("Exception in TokenizerManager:\n" + get_exception_traceback())
\ No newline at end of file
+        print("Exception in TokenizerManager:\n" + get_exception_traceback())
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 46e691283..54bbbfc3d 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -9,8 +9,8 @@ import os
 import sys
 import threading
 import time
-from typing import List, Optional, Union
 from http import HTTPStatus
+from typing import List, Optional, Union
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -45,7 +45,6 @@ from sglang.srt.utils import (
 )
 from sglang.utils import get_exception_traceback
 
-
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -84,6 +83,7 @@ async def flush_cache():
 async def generate_request(obj: GenerateReqInput, request: Request):
     if obj.stream:
+
         async def stream_results():
             try:
                 async for out in tokenizer_manager.generate_request(obj, request):
@@ -99,8 +99,10 @@
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return JSONResponse({"error": {"message": str(e)}},
-                            status_code=HTTPStatus.BAD_REQUEST)
+        return JSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
 
 app.post("/generate")(generate_request)
 app.put("/generate")(generate_request)
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index fbd98b3bb..950915e9f 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -19,7 +19,6 @@ from packaging import version as pkg_version
 from pydantic import BaseModel
 from starlette.middleware.base import BaseHTTPMiddleware
 
-
 logger = logging.getLogger(__name__)
@@ -157,7 +156,9 @@ def allocate_init_ports(
             cur_port += 1
 
     if port and ret_ports[0] != port:
-        logger.warn(f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead.")
+        logger.warn(
+            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
+        )
 
     return ret_ports[0], ret_ports[1:]
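
For readers following the speculative-execution flow in openai.py above: role_end_generate issues a single completion covering the whole assistant turn, and spec_pattern_match splits that completion back into the individual gen() variables. The standalone sketch below replays that matching logic outside the backend class; the sample spec_format list and strings are illustrative, not part of this patch.

# Minimal standalone sketch of the matching step, assuming a spec_format
# list shaped like the one OpenAI.spec_fill and OpenAI.generate collect.
def spec_pattern_match(spec_format, comp):
    for i, term in enumerate(spec_format):
        text = term["text"]
        if text != "":
            # A fixed "fill" term must appear verbatim at the current position.
            if not comp.startswith(text):
                return False
            comp = comp[len(text):]
        else:
            # An empty term is a gen() slot: cut it at its stop string.
            pos = comp.find(term["stop"])
            if pos != -1:
                term["text"] = comp[:pos]
                comp = comp[pos:]  # keep the stop; the next fill starts with it
            elif i == len(spec_format) - 1:
                term["text"] = comp  # only the last slot may take the remainder
            else:
                return False
    return True

# spec_format as it would be collected from
# sgl.assistant("Name:" + gen("name", stop="\n") + "\nJob:" + gen("job", stop="\n"))
spec_format = [
    {"text": "Name:", "stop": None, "name": None},
    {"text": "", "stop": "\n", "name": "name"},
    {"text": "\nJob:", "stop": None, "name": None},
    {"text": "", "stop": "\n", "name": "job"},
]
assert spec_pattern_match(spec_format, "Name: Ada Lovelace\nJob: Mathematician\n")
print({t["name"]: t["text"] for t in spec_format if t["name"] is not None})
# -> {'name': ' Ada Lovelace', 'job': ' Mathematician'}

Fixed fill text must match verbatim and each unnamed slot is cut at its stop string, with only the final slot allowed to consume the remainder; this is why the example files put every gen() inside one assistant block and give each a stop token.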