support speculative execution for openai API (#48)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
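This commit adds speculative execution for the OpenAI API backend. When a sglang function is declared with `@function(api_num_spec_tokens=N)`, the interpreter issues a single long completion request (stop strings disabled, up to N tokens) and tries to serve the following `gen` calls from that speculated text, falling back to a fresh request whenever the speculation stops matching the program.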
examples/usage/openai_speculative.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+from sglang import function, gen, set_default_backend, OpenAI
+
+
+@function(api_num_spec_tokens=512)
+def gen_character_spec(s):
+    s += "Construct a character within the following format:\n"
+    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+    s += "\nPlease generate new Name, Birthday and Job.\n"
+    s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
+    s += "\nJob:" + gen("job", stop="\n") + "\n"
+
+
+set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))
+
+state = gen_character_spec.run()
+
+print("name:", state["name"])
+print("birthday:", state["birthday"])
+print("job:", state["job"])
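The new example is the user-facing surface of the feature: because `gen_character_spec` is declared with `api_num_spec_tokens=512`, the interpreter can try to satisfy all three `gen` calls from one long speculative completion. For contrast, here is a sketch of the same program without speculation (my variant, not part of this commit), which behaves exactly as before, one OpenAI request per `gen`:

    # Hypothetical non-speculative variant, for comparison (not in this commit).
    # Without api_num_spec_tokens, each gen() issues its own OpenAI request.
    from sglang import OpenAI, function, gen, set_default_backend


    @function
    def gen_character_normal(s):
        s += "Construct a character within the following format:\n"
        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
        s += "\nPlease generate new Name, Birthday and Job.\n"
        s += "Name:" + gen("name", stop="\n")            # request 1
        s += "\nBirthday:" + gen("birthday", stop="\n")  # request 2
        s += "\nJob:" + gen("job", stop="\n") + "\n"     # request 3


    set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))
    state = gen_character_normal.run()

The hunks below implement the speculative path, starting with the `function` decorator.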
@@ -20,8 +20,16 @@ from sglang.lang.ir import (
 )


-def function(func: Callable):
-    return SglFunction(func)
+def function(
+    func: Optional[Callable] = None, api_num_spec_tokens: Optional[int] = None
+):
+    if func:
+        return SglFunction(func, api_num_spec_tokens=api_num_spec_tokens)
+
+    def decorator(func):
+        return SglFunction(func, api_num_spec_tokens=api_num_spec_tokens)
+
+    return decorator


 def Runtime(*args, **kwargs):
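The rewritten `function` uses the standard optional-arguments decorator pattern: applied bare (`@function`) it receives the target function as `func` and wraps it directly; applied with arguments (`@function(api_num_spec_tokens=512)`) it is called with `func=None` and returns a decorator that does the wrapping. A minimal self-contained sketch of the same dispatch, with hypothetical names:

    from typing import Callable, Optional


    def logged(func: Optional[Callable] = None, prefix: str = "call"):
        # Hypothetical decorator, mirroring the func-or-None dispatch above.
        def wrap(f):
            def inner(*args, **kwargs):
                print(f"{prefix}: {f.__name__}")  # log, then delegate
                return f(*args, **kwargs)
            return inner

        if func:                 # used bare: @logged
            return wrap(func)
        return wrap              # used with arguments: @logged(prefix="dbg")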
@@ -51,10 +51,14 @@ def run_program(
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint
     assert backend is not None, "Please specify a backend"

     func_kwargs.update(program.bind_arguments)
     stream_executor = StreamExecutor(
-        backend, func_kwargs, default_sampling_para, chat_template=None, stream=stream
+        backend,
+        func_kwargs,
+        default_sampling_para,
+        chat_template=None,
+        stream=stream,
+        api_num_spec_tokens=program.api_num_spec_tokens,
     )
     state = ProgramState(stream_executor)
@@ -175,6 +179,7 @@ class StreamExecutor:
         default_sampling_para,
         chat_template,
         stream,
+        api_num_spec_tokens=None,
         use_thread=True,
     ):
         self.sid = uuid.uuid4().hex
@@ -182,6 +187,7 @@ class StreamExecutor:
         self.arguments: Dict[str, Any] = arguments
         self.default_sampling_para = default_sampling_para
         self.stream = stream
+        self.api_num_spec_tokens = api_num_spec_tokens

         self.variables = {}  # Dict[name: str -> value: str]
         self.variable_event = {}  # Dict[name: str -> event: threading.Event]
@@ -191,6 +197,9 @@ class StreamExecutor:
         # For completion
         self.text_ = ""  # The full text

+        # For speculative execution
+        self.speculated_text = ""
+
         # For chat
         self.messages_ = []  # The messages in the OpenAI API format
         self.chat_template = chat_template or self.backend.get_chat_template()
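The three hunks above are plumbing: `run_program` forwards `program.api_num_spec_tokens` into the `StreamExecutor` constructor, which stores it next to a new `speculated_text` buffer holding the not-yet-consumed tail of the most recent speculative completion. The next hunks teach the executor to consume and refill that buffer.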
@@ -341,6 +350,10 @@ class StreamExecutor:

     def _execute_fill(self, value: str):
         value = str(value)
+        if self.speculated_text.startswith(value):
+            self.speculated_text = self.speculated_text[len(value) :]
+        else:
+            self.speculated_text = ""
         self.text_ += value

     def _execute_image(self, expr: SglImage):
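This is where a speculation is verified. When the program appends a constant string, the executor checks it against the front of the speculated continuation: a match consumes that prefix, a mismatch discards the whole buffer so the next `gen` falls back to a real request. A standalone illustration with hypothetical values:

    # The prefix check from _execute_fill, on plain strings (hypothetical values).
    speculated = "\nBirthday: December 10, 1815.\nJob: Mathematician.\n"

    fill = "\nBirthday:"                     # constant text appended by the program
    if speculated.startswith(fill):
        speculated = speculated[len(fill):]  # match: consume the prefix
    else:
        speculated = ""                      # mismatch: discard the speculation

    print(repr(speculated))  # ' December 10, 1815.\nJob: Mathematician.\n'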
@@ -360,9 +373,61 @@ class StreamExecutor:
         name = expr.name

         if not self.stream:
-            comp, meta_info = self.backend.generate(
-                self, sampling_params=sampling_params
-            )
+            if self.api_num_spec_tokens is not None:
+                stop = sampling_params.stop
+                max_new_tokens = sampling_params.max_new_tokens
+                meta_info = {}
+
+                def regen():
+                    sampling_params.max_new_tokens = max(
+                        sampling_params.max_new_tokens, self.api_num_spec_tokens
+                    )
+                    sampling_params.stop = None
+                    self.speculated_text, meta_info = self.backend.generate(
+                        self, sampling_params=sampling_params
+                    )
+
+                def find_stop():
+                    if isinstance(stop, str):
+                        return self.speculated_text.find(stop), len(stop)
+                    elif isinstance(stop, (tuple, list)):
+                        pos = -1
+                        stop_len = 0
+                        for stop_str in stop:
+                            stop_pos = self.speculated_text.find(stop_str)
+                            if stop_pos != -1 and (pos == -1 or stop_pos < pos):
+                                pos = stop_pos
+                                stop_len = len(stop_str)
+                        return pos, stop_len
+                    else:
+                        raise Exception("Wrong type of stop in sampling parameters.")
+
+                if stop is None:
+                    if len(self.speculated_text) < max_new_tokens:
+                        regen()
+                    comp = self.speculated_text[:max_new_tokens]
+                    self.speculated_text = self.speculated_text[max_new_tokens:]
+                elif isinstance(stop, (str, list, tuple)):
+                    if self.speculated_text == "":
+                        regen()
+                    stop_pos, stop_len = find_stop()
+                    if stop_pos == -1:
+                        stop_pos, stop_len = (
+                            min(
+                                sampling_params.max_new_tokens,
+                                len(self.speculated_text),
+                            ),
+                            0,
+                        )
+                    comp = self.speculated_text[:stop_pos]
+                    self.speculated_text = self.speculated_text[stop_pos:]
+                else:
+                    raise ValueError("Wrong type of stop in sampling parameters.")
+            else:
+                comp, meta_info = self.backend.generate(
+                    self, sampling_params=sampling_params
+                )

         self.text_ += comp

         self.variables[name] = comp
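This hunk is the core of the feature. On the first non-streaming `gen`, `regen()` requests up to `api_num_spec_tokens` tokens with stop strings disabled, so the model overshoots the current field and, with luck, also emits the fields that follow. Later `gen` calls then try to answer from `speculated_text` alone: the earliest stop string is located and everything before it becomes the variable's value; a fresh request is only made when the buffer is empty. Note that the stop string itself is left in the buffer, to be consumed by the program's next constant fill through the `_execute_fill` prefix check above. The slicing rules can be exercised in isolation; a sketch under the assumption that the buffer is a plain string (the names are mine, not the commit's):

    def consume(buffer: str, stop, max_new_tokens: int):
        """Split a speculated buffer into (completion, remainder),
        following the hunk's slicing rules."""
        if stop is None:
            return buffer[:max_new_tokens], buffer[max_new_tokens:]
        stops = [stop] if isinstance(stop, str) else list(stop)
        # earliest occurrence of any stop string, -1 if none is present
        pos = min((buffer.find(s) for s in stops if s in buffer), default=-1)
        if pos == -1:                      # no stop found: take what is allowed
            pos = min(max_new_tokens, len(buffer))
        return buffer[:pos], buffer[pos:]  # the stop stays in the remainder


    spec = " Ada Lovelace.\nBirthday: December 10, 1815.\n"
    name, spec = consume(spec, "\n", 64)
    print(repr(name))  # ' Ada Lovelace.'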
@@ -95,8 +95,9 @@ class SglSamplingParams:


 class SglFunction:
-    def __init__(self, func, bind_arguments=None):
+    def __init__(self, func, api_num_spec_tokens=None, bind_arguments=None):
         self.func = func
+        self.api_num_spec_tokens = api_num_spec_tokens
         self.bind_arguments = bind_arguments or {}
         self.pin_prefix_rid = None
@@ -60,7 +60,9 @@ class DetokenizerManager:
             if first_token.startswith("▁"):
                 output_strs[i] = " " + output_strs[i]

-            output_strs[i] = recv_obj.output_and_fast_forward_strs[i] + output_strs[i]
+            output_strs[i] = (
+                recv_obj.output_and_fast_forward_strs[i] + output_strs[i]
+            )

         self.send_to_tokenizer.send_pyobj(
             BatchStrOut(
@@ -12,6 +12,7 @@ import rpyc
 import torch
 from rpyc.utils.classic import obtain
 from rpyc.utils.server import ThreadedServer
+from sglang.srt.constrained.fast_forward import FastForwardCache
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import BatchTokenIDOut, TokenizedGenerateReqInput
@@ -21,7 +22,6 @@ from sglang.srt.managers.router.radix_cache import RadixCache
 from sglang.srt.managers.router.scheduler import Scheduler
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.constrained.fast_forward import FastForwardCache
 from sglang.srt.utils import (
     get_exception_traceback,
     get_int_token_logit_bias,
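The two import hunks appear to only relocate the `FastForwardCache` import to keep the import block sorted; together with the `DetokenizerManager` change above (which prepends `output_and_fast_forward_strs` to the decoded output, here merely reflowed across lines), they belong to the runtime's fast-forward decoding path rather than to the OpenAI speculation itself.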
@@ -200,6 +200,7 @@ class TokenizerManager:
         )
         tokenized_obj = TokenizedGenerateReqInput(
             rid=rid,
+            input_text=obj.text[i],
             input_ids=input_ids,
             pixel_values=pixel_values,
             image_hash=image_hash,
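`TokenizerManager` now also forwards the original `input_text` with each tokenized request, presumably so downstream components can work with the raw string instead of re-detokenizing the input ids.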
test/lang/test_openai_spec.py (new file, 68 lines)
@@ -0,0 +1,68 @@
+from sglang import OpenAI, function, gen, set_default_backend
+
+
+@function()
+def gen_character_default(s):
+    s += "Construct a character within the following format:\n"
+    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
+    s += "\nPlease generate new Name, Birthday and Job.\n"
+    s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
+    s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n"
+
+
+@function(api_num_spec_tokens=512)
+def gen_character_spec(s):
+    s += "Construct a character within the following format:\n"
+    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
+    s += "\nPlease generate new Name, Birthday and Job.\n"
+    s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
+    s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n"
+
+
+@function(api_num_spec_tokens=512)
+def gen_character_no_stop(s):
+    s += "Construct a character within the following format:\n"
+    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
+    s += "\nPlease generate new Name, Birthday and Job.\n"
+    s += "Name:" + gen("name") + "\nBirthday:" + gen("birthday")
+    s += "\nJob:" + gen("job") + "\nWelcome.\n"
+
+
+@function(api_num_spec_tokens=512)
+def gen_character_multi_stop(s):
+    s += "Construct a character within the following format:\n"
+    s += (
+        "Name: Steve Jobs.###Birthday: February 24, 1955.###Job: Apple CEO.\nWelcome.\n"
+    )
+    s += "\nPlease generate new Name, Birthday and Job.\n"
+    s += "Name:" + gen("name", stop=["\n", "###"])
+    s += "###Birthday:" + gen("birthday", stop=["\n", "###"])
+    s += "###Job:" + gen("job", stop=["\n", "###"]) + "\nWelcome.\n"
+
+
+set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))
+
+state = gen_character_default.run()
+print(state.text())
+
+print("=" * 60)
+
+state = gen_character_no_stop.run()
+
+print("name###", state["name"])
+print("birthday###:", state["birthday"])
+print("job###", state["job"])
+
+print("=" * 60)
+
+state = gen_character_multi_stop.run()
+print(state.text())
+
+print("=" * 60)
+
+state = gen_character_spec.run()
+print(state.text())
+
+print("name###", state["name"])
+print("birthday###", state["birthday"])
+print("job###", state["job"])
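The test programs cover the stop-handling branches introduced in `_execute_gen`: a non-speculative baseline (`gen_character_default`), a single stop string (`gen_character_spec`), no stop at all, where the buffer is sliced by `max_new_tokens` (`gen_character_no_stop`), and a list of stop strings (`gen_character_multi_stop`).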
@@ -1,7 +1,6 @@
 import argparse
 from enum import Enum

-import sglang as sgl
 from pydantic import BaseModel, constr
 from sglang.srt.constrained.json_schema import build_regex_from_object
 from sglang.test.test_utils import (
@@ -9,6 +8,8 @@ from sglang.test.test_utils import (
     select_sglang_backend,
 )

+import sglang as sgl
+
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

 ip_fast_forward = (
@@ -2,13 +2,14 @@ import argparse
 import random
 import string

-import sglang as sgl
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
     select_sglang_backend,
 )
 from vllm.transformers_utils.tokenizer import get_tokenizer

+import sglang as sgl
+
 TOKENIZER = None
 RANDOM_PREFILL_LEN = None
 RANDOM_DECODE_LEN = None