diff --git a/examples/usage/openai_speculative.py b/examples/usage/openai_speculative.py
new file mode 100644
index 000000000..8fb85255c
--- /dev/null
+++ b/examples/usage/openai_speculative.py
@@ -0,0 +1,19 @@
+from sglang import function, gen, set_default_backend, OpenAI
+
+
+@function(api_num_spec_tokens=512)
+def gen_character_spec(s):
+    s += "Construct a character within the following format:\n"
+    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+    s += "\nPlease generate new Name, Birthday and Job.\n"
+    s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
+    s += "\nJob:" + gen("job", stop="\n") + "\n"
+
+
+set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))
+
+state = gen_character_spec.run()
+
+print("name:", state["name"])
+print("birthday:", state["birthday"])
+print("job:", state["job"])
diff --git a/python/sglang/api.py b/python/sglang/api.py
index 410cb6fb4..de36287c2 100644
--- a/python/sglang/api.py
+++ b/python/sglang/api.py
@@ -20,8 +20,22 @@ from sglang.lang.ir import (
 )
 
 
-def function(func: Callable):
-    return SglFunction(func)
+def function(
+    func: Optional[Callable] = None, api_num_spec_tokens: Optional[int] = None
+):
+    """Wrap ``func`` in an SglFunction.
+
+    Works as a bare decorator (``@function``) and as a decorator factory
+    (``@function(api_num_spec_tokens=512)``); ``api_num_spec_tokens`` is
+    passed through to SglFunction unchanged.
+    """
+    if func is not None:
+        return SglFunction(func, api_num_spec_tokens=api_num_spec_tokens)
+
+    def decorator(func):
+        return SglFunction(func, api_num_spec_tokens=api_num_spec_tokens)
+
+    return decorator
 
 
 def Runtime(*args, **kwargs):
diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py
index 144a4e79f..f281b0b2a 100644
--- a/python/sglang/lang/interpreter.py
+++ b/python/sglang/lang/interpreter.py
@@ -51,10 +51,14 @@ def run_program(
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint
     assert backend is not None, "Please specify a backend"
-
     func_kwargs.update(program.bind_arguments)
     stream_executor = StreamExecutor(
-        backend, func_kwargs, default_sampling_para, chat_template=None, stream=stream
+        backend,
+        func_kwargs,
+        default_sampling_para,
+        chat_template=None,
+        stream=stream,
+        api_num_spec_tokens=program.api_num_spec_tokens,
     )
     state = ProgramState(stream_executor)
 
@@ -175,6 +179,7 @@ class StreamExecutor:
         default_sampling_para,
         chat_template,
         stream,
+        api_num_spec_tokens=None,
         use_thread=True,
     ):
         self.sid = uuid.uuid4().hex
@@ -182,6 +187,7 @@ class StreamExecutor:
         self.arguments: Dict[str, Any] = arguments
         self.default_sampling_para = default_sampling_para
         self.stream = stream
+        self.api_num_spec_tokens = api_num_spec_tokens
 
         self.variables = {}  # Dict[name: str -> value: str]
         self.variable_event = {}  # Dict[name: str -> event: threading.Event]
@@ -191,6 +197,9 @@ class StreamExecutor:
         # For completion
         self.text_ = ""  # The full text
 
+        # For speculative execution
+        self.speculated_text = ""
+
         # For chat
         self.messages_ = []  # The messages in the OpenAI API format
         self.chat_template = chat_template or self.backend.get_chat_template()
@@ -341,6 +350,10 @@ class StreamExecutor:
 
     def _execute_fill(self, value: str):
         value = str(value)
+        if self.speculated_text.startswith(value):
+            self.speculated_text = self.speculated_text[len(value) :]
+        else:
+            self.speculated_text = ""
         self.text_ += value
 
     def _execute_image(self, expr: SglImage):
@@ -360,9 +373,71 @@ class StreamExecutor:
         name = expr.name
 
         if not self.stream:
-            comp, meta_info = self.backend.generate(
-                self, sampling_params=sampling_params
-            )
+            if self.api_num_spec_tokens is not None:
+                # Speculative path: ask the backend once for a long completion
+                # with stop strings disabled, keep it in self.speculated_text,
+                # and carve this gen() result out of that cached text.
+                stop = sampling_params.stop
+                max_new_tokens = sampling_params.max_new_tokens
+                meta_info = {}
+
+                def regen():
+                    # Without nonlocal the tuple assignment below would bind a
+                    # regen-local meta_info and the backend metadata would be lost.
+                    nonlocal meta_info
+                    sampling_params.max_new_tokens = max(
+                        sampling_params.max_new_tokens, self.api_num_spec_tokens
+                    )
+                    sampling_params.stop = None
+                    self.speculated_text, meta_info = self.backend.generate(
+                        self, sampling_params=sampling_params
+                    )
+
+                def find_stop():
+                    # Return (position, length) of the earliest stop string in
+                    # the cached text; position is -1 when none is present.
+                    if isinstance(stop, str):
+                        return self.speculated_text.find(stop), len(stop)
+                    elif isinstance(stop, (tuple, list)):
+                        pos = -1
+                        stop_len = 0
+                        for stop_str in stop:
+                            stop_pos = self.speculated_text.find(stop_str)
+                            if stop_pos != -1 and (pos == -1 or stop_pos < pos):
+                                pos = stop_pos
+                                stop_len = len(stop_str)
+                        return pos, stop_len
+                    else:
+                        raise ValueError("Wrong type of stop in sampling parameters.")
+
+                if stop is None:
+                    if len(self.speculated_text) < max_new_tokens:
+                        regen()
+                    comp = self.speculated_text[:max_new_tokens]
+                    self.speculated_text = self.speculated_text[max_new_tokens:]
+                elif isinstance(stop, (str, list, tuple)):
+                    if self.speculated_text == "":
+                        regen()
+                    stop_pos, stop_len = find_stop()
+                    if stop_pos == -1:
+                        stop_pos, stop_len = (
+                            min(
+                                sampling_params.max_new_tokens,
+                                len(self.speculated_text),
+                            ),
+                            0,
+                        )
+                    # The stop string itself is left at the head of the cache;
+                    # _execute_fill consumes matching text from there.
+                    comp = self.speculated_text[:stop_pos]
+                    self.speculated_text = self.speculated_text[stop_pos:]
+                else:
+                    raise ValueError("Wrong type of stop in sampling parameters.")
+            else:
+                comp, meta_info = self.backend.generate(
+                    self, sampling_params=sampling_params
+                )
+
             self.text_ += comp
 
             self.variables[name] = comp
diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py
index 6202bffaf..2803de51d 100644
--- a/python/sglang/lang/ir.py
+++ b/python/sglang/lang/ir.py
@@ -95,8 +95,9 @@ class SglSamplingParams:
 
 
 class SglFunction:
-    def __init__(self, func, bind_arguments=None):
+    def __init__(self, func, api_num_spec_tokens=None, bind_arguments=None):
         self.func = func
+        self.api_num_spec_tokens = api_num_spec_tokens
         self.bind_arguments = bind_arguments or {}
         self.pin_prefix_rid = None
 
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index 9670d7dda..20c8119c7 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -60,7 +60,9 @@ class DetokenizerManager:
                         if first_token.startswith("▁"):
                             output_strs[i] = " " + output_strs[i]
 
-                    output_strs[i] = recv_obj.output_and_fast_forward_strs[i] + output_strs[i]
+                    output_strs[i] = (
+                        recv_obj.output_and_fast_forward_strs[i] + output_strs[i]
+                    )
 
                 self.send_to_tokenizer.send_pyobj(
                     BatchStrOut(
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index 88f7291d1..dfaa8f12b 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -12,6 +12,7 @@ import rpyc
 import torch
 from rpyc.utils.classic import obtain
 from rpyc.utils.server import ThreadedServer
+from sglang.srt.constrained.fast_forward import FastForwardCache
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import BatchTokenIDOut, TokenizedGenerateReqInput
@@ -21,7 +22,6 @@ from sglang.srt.managers.router.radix_cache import RadixCache
 from sglang.srt.managers.router.scheduler import Scheduler
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.constrained.fast_forward import FastForwardCache
 from sglang.srt.utils import (
     get_exception_traceback,
     get_int_token_logit_bias,
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index a6c49c45d..85a7a04f6 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -200,6 +200,7 @@ class TokenizerManager:
                     )
                 tokenized_obj = TokenizedGenerateReqInput(
                     rid=rid,
+                    input_text=obj.text[i],
                     input_ids=input_ids,
                     pixel_values=pixel_values,
                     image_hash=image_hash,
diff --git a/test/lang/test_openai_spec.py b/test/lang/test_openai_spec.py
new file mode 100644
index 000000000..fdda19127
--- /dev/null
+++ b/test/lang/test_openai_spec.py
@@ -0,0 +1,68 @@
+from sglang import OpenAI, function, gen, set_default_backend
+
+
+@function()
+def gen_character_default(s):
+    s += "Construct a character within the following format:\n"
+    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
+    s += "\nPlease generate new Name, Birthday and Job.\n"
+    s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
+    s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n"
+
+
+@function(api_num_spec_tokens=512)
+def gen_character_spec(s):
+    s += "Construct a character within the following format:\n"
+    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
+    s += "\nPlease generate new Name, Birthday and Job.\n"
+    s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
+    s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n"
+
+
+@function(api_num_spec_tokens=512)
+def gen_character_no_stop(s):
+    s += "Construct a character within the following format:\n"
+    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
+    s += "\nPlease generate new Name, Birthday and Job.\n"
+    s += "Name:" + gen("name") + "\nBirthday:" + gen("birthday")
+    s += "\nJob:" + gen("job") + "\nWelcome.\n"
+
+
+@function(api_num_spec_tokens=512)
+def gen_character_multi_stop(s):
+    s += "Construct a character within the following format:\n"
+    s += (
+        "Name: Steve Jobs.###Birthday: February 24, 1955.###Job: Apple CEO.\nWelcome.\n"
+    )
+    s += "\nPlease generate new Name, Birthday and Job.\n"
+    s += "Name:" + gen("name", stop=["\n", "###"])
+    s += "###Birthday:" + gen("birthday", stop=["\n", "###"])
+    s += "###Job:" + gen("job", stop=["\n", "###"]) + "\nWelcome.\n"
+
+
+set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))
+
+state = gen_character_default.run()
+print(state.text())
+
+print("=" * 60)
+
+state = gen_character_no_stop.run()
+
+print("name###", state["name"])
+print("birthday###:", state["birthday"])
+print("job###", state["job"])
+
+print("=" * 60)
+
+state = gen_character_multi_stop.run()
+print(state.text())
+
+print("=" * 60)
+
+state = gen_character_spec.run()
+print(state.text())
+
+print("name###", state["name"])
+print("birthday###", state["birthday"])
+print("job###", state["job"])
diff --git a/test/srt/test_fast_forward.py b/test/srt/test_fast_forward.py
index a94d15ae5..505ce2928 100644
--- a/test/srt/test_fast_forward.py
+++ b/test/srt/test_fast_forward.py
@@ -1,7 +1,6 @@
 import argparse
 from enum import Enum
 
-import sglang as sgl
 from pydantic import BaseModel, constr
 from sglang.srt.constrained.json_schema import build_regex_from_object
 from sglang.test.test_utils import (
@@ -9,6 +8,8 @@ from sglang.test.test_utils import (
     select_sglang_backend,
 )
 
+import sglang as sgl
+
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
 
 ip_fast_forward = (
diff --git a/test/srt/test_robust.py b/test/srt/test_robust.py
index 9b4ceaf5f..5b479318f 100644
--- a/test/srt/test_robust.py
+++ b/test/srt/test_robust.py
@@ -2,13 +2,14 @@ import argparse
 import random
 import string
 
-import sglang as sgl
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
     select_sglang_backend,
 )
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
+import sglang as sgl
+
 TOKENIZER = None
 RANDOM_PREFILL_LEN = None
 RANDOM_DECODE_LEN = None