openai chat speculative execution (#250)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
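This adds API-level speculative execution for OpenAI chat models. When a program is decorated with api_num_spec_tokens, the literal fills and sgl.gen calls inside an assistant turn are not sent one by one; the backend records them (spec_fill / spec_format), issues a single chat completion when the role ends (role_end_generate), and pattern-matches the speculated answer back into the individual gen variables, retrying up to spec_max_num_tries times.

A rough usage sketch (not part of this diff; it assumes the sglang frontend API of this version, and the program body, model name, and variable names are illustrative):

    import sglang as sgl

    # With api_num_spec_tokens set, both gen calls below are served by one
    # speculative chat completion instead of two separate API requests.
    @sgl.function(api_num_spec_tokens=128)
    def city_facts(s, city):
        s += sgl.system("You are a helpful assistant.")
        s += sgl.user("Tell me about " + city + ".")
        s += sgl.assistant(
            "The country is " + sgl.gen("country", stop="\n")
            + "\nThe population is " + sgl.gen("population", stop="\n")
        )

    sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo"))
    state = city_facts.run(city="Paris")
    print(state["country"], state["population"])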
@@ -9,6 +9,7 @@ class BaseBackend:
     def __init__(self) -> None:
         self.support_concate_and_append = False
         self.chat_template = get_chat_template("default")
+        self.api_num_spec_tokens = None

     def get_model_name(self):
         raise NotImplementedError()
@@ -1,5 +1,6 @@
 import logging
 import time
+import warnings
 from typing import Callable, List, Optional, Union

 import numpy as np
@@ -80,7 +81,15 @@ class OpenAI(BaseBackend):
         else:
             self.is_chat_model = True

-        self.chat_begin_str = self.chat_template.role_prefix_and_suffix["assistant"][0]
+        self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]
+
+        self.spec_kwargs = {}
+        self.spec_format = []
+        self.spec_max_num_tries = 3
+        self.api_num_spec_tokens = None
+
+    def set_api_num_spec_tokens(self, num):
+        self.api_num_spec_tokens = num

     def get_chat_template(self):
         return self.chat_template
@@ -89,15 +98,45 @@ class OpenAI(BaseBackend):
         self,
         s: StreamExecutor,
         sampling_params: SglSamplingParams,
+        name=None,
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-                if not s.text_.endswith(self.chat_begin_str):
-                    raise RuntimeError(
-                        "This use case is not supported. "
-                        "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
+                if self.api_num_spec_tokens is None:
+                    if not s.text_.endswith(self.chat_prefix):
+                        raise RuntimeError(
+                            "This use case is not supported if api speculative execution is off. "
+                            "For OpenAI chat models, sgl.gen must be right after sgl.assistant."
+                            "Example of adding api speculative execution: @function(api_num_spec_tokens=128)."
+                        )
+                    prompt = s.messages_
+                else:
+                    # collect assistant answer format
+                    if "max_tokens" not in self.spec_kwargs:
+                        self.spec_kwargs["max_tokens"] = self.api_num_spec_tokens
+                    else:
+                        assert (
+                            self.spec_kwargs["max_tokens"] == self.api_num_spec_tokens
+                        )
+                    params = sampling_params.to_openai_kwargs()
+                    for key, value in params.items():
+                        if key in ["stop"]:
+                            continue
+                        if key in ["max_tokens"]:
+                            warnings.warn(
+                                "The parameter max_tokens will be overwritten by speculated number of tokens."
+                            )
+                            continue
+                        if key not in self.spec_kwargs:
+                            self.spec_kwargs[key] = value
+                        else:
+                            assert (
+                                value == self.spec_kwargs[key]
+                            ), "sampling parameters should be consistent if turn on api speculative execution."
+                    self.spec_format.append(
+                        {"text": "", "stop": params["stop"], "name": name}
+                    )
+                    prompt = s.messages_
+                    return "", {}
             else:
                 prompt = s.text_
@@ -110,6 +149,9 @@ class OpenAI(BaseBackend):
                 **kwargs,
             )
         elif sampling_params.dtype in [str, "str", "string"]:
+            assert (
+                not self.is_chat_model
+            ), "constrained type not supported on chat model"
             kwargs = sampling_params.to_openai_kwargs()
             kwargs.pop("stop")
             comp = openai_completion(
@@ -122,6 +164,9 @@ class OpenAI(BaseBackend):
             )
             comp = '"' + comp + '"'
         elif sampling_params.dtype in [int, "int"]:
+            assert (
+                not self.is_chat_model
+            ), "constrained type not supported on chat model"
             kwargs = sampling_params.to_openai_kwargs()
             kwargs.pop("stop")
             comp = openai_completion(
@@ -138,6 +183,62 @@ class OpenAI(BaseBackend):

         return comp, {}

+    def spec_fill(self, value: str):
+        assert self.is_chat_model
+        self.spec_format.append({"text": value, "stop": None, "name": None})
+
+    def spec_pattern_match(self, comp):
+        for i, term in enumerate(self.spec_format):
+            text = term["text"]
+            if text != "":
+                if comp.startswith(text):
+                    comp = comp[len(text) :]
+                else:
+                    return False
+            else:
+                pos = comp.find(term["stop"])
+                if pos != -1:
+                    term["text"] = comp[:pos]
+                    comp = comp[pos:]
+                else:
+                    if i == len(self.spec_format) - 1:
+                        term["text"] = comp
+                    else:
+                        return False
+        return True
+
+    def role_end_generate(
+        self,
+        s: StreamExecutor,
+    ):
+        if self.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
+            return
+
+        comp = ""
+        if not all(x["name"] is None for x in self.spec_format):
+            for i in range(self.spec_max_num_tries):
+                comp = openai_completion(
+                    client=self.client,
+                    is_chat=self.is_chat_model,
+                    model=self.model_name,
+                    prompt=s.messages_,
+                    **self.spec_kwargs,
+                )
+                if self.spec_pattern_match(comp):
+                    break
+
+        for term in self.spec_format:
+            stop = term["stop"] if term["stop"] is not None else ""
+            s.text_ += term["text"]
+            name = term["name"]
+            if name is not None:
+                s.variables[name] = term["text"]
+                s.meta_info[name] = {}
+                s.variable_event[name].set()
+
+        self.spec_kwargs = {}
+        self.spec_format = []
+
     def generate_stream(
         self,
         s: StreamExecutor,
@@ -145,7 +246,7 @@ class OpenAI(BaseBackend):
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-                if not s.text_.endswith(self.chat_begin_str):
+                if not s.text_.endswith(self.chat_prefix):
                     raise RuntimeError(
                         "This use case is not supported. "
                         "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
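The core of the chat-model path above is spec_pattern_match: literal fill text must match the speculated completion as a prefix, and each gen slot greedily captures everything up to its stop string (the stop string is not consumed by the gen slot, so the following literal fill must begin with it). A standalone sketch of that matching rule, with a hypothetical helper name and field values (not code from this commit):

    def split_speculated(spec_format, comp):
        # spec_format entries: {"text", "stop", "name"}; literal fills have text != "",
        # gen slots have text == "" plus a stop string. On success, each gen slot's
        # "text" is filled with its matched span, mirroring OpenAI.spec_pattern_match.
        for i, term in enumerate(spec_format):
            if term["text"] != "":
                if not comp.startswith(term["text"]):
                    return False
                comp = comp[len(term["text"]):]
            else:
                pos = comp.find(term["stop"])
                if pos == -1:
                    if i != len(spec_format) - 1:
                        return False
                    term["text"] = comp
                else:
                    term["text"] = comp[:pos]
                    comp = comp[pos:]
        return True

    fmt = [
        {"text": "", "stop": "\n", "name": "country"},
        {"text": "\nThe population is ", "stop": None, "name": None},
        {"text": "", "stop": "\n", "name": "population"},
    ]
    ok = split_speculated(fmt, "France\nThe population is 2.1 million\n")
    print(ok)                                      # True
    print([t["text"] for t in fmt if t["name"]])   # ['France', '2.1 million']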
@@ -266,4 +266,4 @@ class RuntimeEndpoint(BaseBackend):

     def _assert_success(self, res):
         if res.status_code != 200:
-            raise RuntimeError(res.json())
+            raise RuntimeError(res.json())
@@ -6,6 +6,7 @@ import multiprocessing
 import queue
 import threading
 import uuid
+import warnings
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Union
@@ -185,7 +186,6 @@ class StreamExecutor:
         self.arguments: Dict[str, Any] = arguments
         self.default_sampling_para = default_sampling_para
         self.stream = stream
-        self.api_num_spec_tokens = api_num_spec_tokens

         self.variables = {}  # Dict[name: str -> value: str]
         self.variable_event = {}  # Dict[name: str -> event: threading.Event]
@@ -197,6 +197,9 @@ class StreamExecutor:
         self.text_ = ""  # The full text

         # For speculative execution
+        from sglang.backend.openai import OpenAI
+        if isinstance(backend, OpenAI):
+            self.backend.set_api_num_spec_tokens(api_num_spec_tokens)
         self.speculated_text = ""

         # For chat
@@ -322,7 +325,7 @@ class StreamExecutor:
             try:
                 self._execute(expr)
             except Exception as e:
-                # print(f"Error in stream_executor: {get_exception_traceback()}")
+                warnings.warn(f"Error in stream_executor: {get_exception_traceback()}")
                 error = e
                 break
             self.queue.task_done()
@@ -391,12 +394,23 @@ class StreamExecutor:
         else:
             raise ValueError(f"Unknown type: {type(other)}")

-    def _execute_fill(self, value: str):
+    def _execute_fill(self, value: str, prefix=False):
         value = str(value)

+        if (
+            self.cur_role == "assistant"
+            and self.backend.api_num_spec_tokens is not None
+            and self.backend.is_chat_model
+            and not prefix
+        ):
+            self.backend.spec_fill(value)
+            return
+
         if self.speculated_text.startswith(value):
             self.speculated_text = self.speculated_text[len(value) :]
+        else:
+            self.speculated_text = ""

         self.text_ += value

     def _execute_image(self, expr: SglImage):
@@ -426,14 +440,29 @@ class StreamExecutor:
         name = expr.name

         if not self.stream:
-            if self.api_num_spec_tokens is not None:
+            if self.backend.api_num_spec_tokens is None:
+                comp, meta_info = self.backend.generate(
+                    self,
+                    sampling_params=sampling_params,
+                )
+
+            elif self.backend.is_chat_model:
+                # spec on model with only chat interface
+                comp, meta_info = self.backend.generate(
+                    self,
+                    sampling_params=sampling_params,
+                    name=name,
+                )
+                return
+
+            else:  # spec on model with completion
                 stop = sampling_params.stop
                 max_new_tokens = sampling_params.max_new_tokens
                 meta_info = {}

                 def regen():
                     sampling_params.max_new_tokens = max(
-                        sampling_params.max_new_tokens, self.api_num_spec_tokens
+                        sampling_params.max_new_tokens, self.backend.api_num_spec_tokens
                     )
                     sampling_params.stop = None
                     self.speculated_text, meta_info = self.backend.generate(
@@ -442,16 +471,14 @@ class StreamExecutor:

                 def find_stop():
                     if isinstance(stop, str):
-                        return self.speculated_text.find(stop), len(stop)
+                        return self.speculated_text.find(stop)
                     elif isinstance(stop, (tuple, list)):
                         pos = -1
-                        stop_len = 0
                         for stop_str in stop:
                             stop_pos = self.speculated_text.find(stop_str)
                             if stop_pos != -1 and (pos == -1 or stop_pos < pos):
                                 pos = stop_pos
-                                stop_len = len(stop_str)
-                        return pos, stop_len
+                        return pos
                     else:
                         raise Exception("Wrong type of stop in sampling parameters.")
@@ -463,23 +490,16 @@ class StreamExecutor:
                 elif isinstance(stop, (str, list, tuple)):
                     if self.speculated_text == "":
                         regen()
-                    stop_pos, stop_len = find_stop()
+                    stop_pos = find_stop()
                     if stop_pos == -1:
-                        stop_pos, stop_len = (
-                            min(
-                                sampling_params.max_new_tokens,
-                                len(self.speculated_text),
-                            ),
-                            0,
+                        stop_pos = min(
+                            sampling_params.max_new_tokens,
+                            len(self.speculated_text),
                         )
                     comp = self.speculated_text[:stop_pos]
                     self.speculated_text = self.speculated_text[stop_pos:]
                 else:
                     raise ValueError("Wrong type of stop in sampling parameters.")
-            else:
-                comp, meta_info = self.backend.generate(
-                    self, sampling_params=sampling_params
-                )

         self.text_ += comp
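For completion-style models, the speculative path above overgenerates once (regen() drops the stop condition and raises max_new_tokens to api_num_spec_tokens) and then serves each subsequent gen by slicing the cached speculated_text at the first stop string, falling back to a length cut when no stop is found. A minimal standalone sketch of that slicing idea, with hypothetical names (not code from this commit):

    def consume_speculation(cache, stop, max_new_tokens, regen):
        # Serve one gen call from the speculation cache: refill via regen() if empty,
        # cut at the first stop string, or fall back to a plain length cut.
        if cache == "":
            cache = regen()
        pos = cache.find(stop)
        if pos == -1:
            pos = min(max_new_tokens, len(cache))
        return cache[:pos], cache[pos:]

    def speculate_once():
        return "France\nabout 2.1 million\n"

    # One speculated completion answers two gen(..., stop="\n") calls; the literal
    # "\n" written by the program between the two gens consumes the stop string.
    country, cache = consume_speculation("", "\n", 64, speculate_once)
    cache = cache.removeprefix("\n")
    population, cache = consume_speculation(cache, "\n", 64, speculate_once)
    print(country, population)   # France about 2.1 million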
@@ -487,6 +507,9 @@ class StreamExecutor:
             self.meta_info[name] = meta_info
             self.variable_event[name].set()
         else:
+            assert (
+                self.backend.api_num_spec_tokens is None
+            ), "stream is not supported with api speculative execution"
             generator = self.backend.generate_stream(
                 self, sampling_params=sampling_params
             )
@@ -542,10 +565,18 @@ class StreamExecutor:

         prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)

-        self._execute_fill(prefix)
+        self._execute_fill(prefix, prefix=True)
         self.cur_role_begin_pos = len(self.text_)

     def _execute_role_end(self, expr: SglRoleEnd):
+        if (
+            self.cur_role == "assistant"
+            and self.backend.api_num_spec_tokens is not None
+            and self.backend.is_chat_model
+        ):
+            self.backend.role_end_generate(self)
+        self.cur_role = None
+
         new_text = self.text_[self.cur_role_begin_pos :].lstrip()

         _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
@@ -572,8 +603,6 @@ class StreamExecutor:
             # OpenAI chat API format
             self.messages_.append({"role": expr.role, "content": new_text})

-        self.cur_role = None
-
     def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
         self.variables[expr.name] = int(len(self.text_))
@@ -31,8 +31,9 @@ class GenerateReqInput:

     def post_init(self):

-        if ((self.text is None and self.input_ids is None) or
-                (self.text is not None and self.input_ids is not None)):
+        if (self.text is None and self.input_ids is None) or (
+            self.text is not None and self.input_ids is not None
+        ):
             raise ValueError("Either text or input_ids should be provided.")

         if self.text is not None:
@@ -38,7 +38,6 @@ from sglang.srt.utils import (
 )
 from sglang.utils import get_exception_traceback

-
 logger = logging.getLogger("model_rpc")
 vllm_default_logger.setLevel(logging.WARN)
 logging.getLogger("vllm.utils").setLevel(logging.WARN)
@@ -341,7 +341,6 @@ class TokenizerManager:
         return top_logprobs


-
 global global_processor


@@ -385,4 +384,4 @@ def get_pixel_values(
         pixel_values = pixel_values.astype(np.float16)
         return pixel_values, image_hash, image.size
     except Exception:
-        print("Exception in TokenizerManager:\n" + get_exception_traceback())
+        print("Exception in TokenizerManager:\n" + get_exception_traceback())
@@ -9,8 +9,8 @@ import os
 import sys
 import threading
 import time
-from typing import List, Optional, Union
 from http import HTTPStatus
+from typing import List, Optional, Union

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -45,7 +45,6 @@ from sglang.srt.utils import (
 )
 from sglang.utils import get_exception_traceback

-
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


@@ -84,6 +83,7 @@ async def flush_cache():

 async def generate_request(obj: GenerateReqInput, request: Request):
     if obj.stream:
+
         async def stream_results():
             try:
                 async for out in tokenizer_manager.generate_request(obj, request):
@@ -99,8 +99,10 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return JSONResponse({"error": {"message": str(e)}},
-                            status_code=HTTPStatus.BAD_REQUEST)
+        return JSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+

 app.post("/generate")(generate_request)
 app.put("/generate")(generate_request)
@@ -19,7 +19,6 @@ from packaging import version as pkg_version
 from pydantic import BaseModel
 from starlette.middleware.base import BaseHTTPMiddleware

-
 logger = logging.getLogger(__name__)


@@ -157,7 +156,9 @@ def allocate_init_ports(
         cur_port += 1

     if port and ret_ports[0] != port:
-        logger.warn(f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead.")
+        logger.warn(
+            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
+        )

     return ret_ports[0], ret_ports[1:]