release initial code

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
Co-authored-by: parasol-aser <3848358+parasol-aser@users.noreply.github.com>
Co-authored-by: LiviaSun <33578456+ChuyueSun@users.noreply.github.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
Lianmin Zheng
2024-01-08 04:37:50 +00:00
parent f6d40df0ee
commit 22085081bb
145 changed files with 17802 additions and 2 deletions

View File

View File

@@ -0,0 +1,186 @@
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, Dict, List, Tuple
class ChatTemplateStyle(Enum):
    """How role segments are stitched together into a prompt."""

    PLAIN = auto()  # simple per-role prefix/suffix concatenation
    LLAMA2 = auto()  # llama-2 style: the system block is folded into the first user turn


@dataclass
class ChatTemplate:
    """A chat prompt template.

    Attributes:
        name: Registry key of the template.
        default_system_prompt: Substituted when a system message has content None.
        role_prefix_and_suffix: Maps role name -> (prefix, suffix) pair.
        image_token: Placeholder inserted into the text for an image.
        style: Concatenation style (see ChatTemplateStyle).
    """

    name: str
    default_system_prompt: str
    # Fixed annotation: each value is a (prefix, suffix) pair, i.e. a 2-tuple.
    role_prefix_and_suffix: Dict[str, Tuple[str, str]]
    image_token: str = "<image>"
    style: ChatTemplateStyle = ChatTemplateStyle.PLAIN

    def get_prefix_and_suffix(self, role, hist_messages):
        """Return the (prefix, suffix) for `role`, given the preceding messages."""
        if self.style == ChatTemplateStyle.PLAIN:
            return self.role_prefix_and_suffix[role]
        elif self.style == ChatTemplateStyle.LLAMA2:
            if len(hist_messages) == 0 and role == "system":
                # The system block lives inside the first user turn, so the
                # user prefix is emitted together with the system prefix.
                return (
                    self.role_prefix_and_suffix["user"][0]
                    + self.role_prefix_and_suffix["system"][0],
                    self.role_prefix_and_suffix["system"][1],
                )
            elif (
                len(hist_messages) == 1
                and role == "user"
                and hist_messages[0]["content"] is not None
            ):
                # The user prefix was already emitted with the system block above.
                return ("", self.role_prefix_and_suffix["user"][1])
            return self.role_prefix_and_suffix[role]
        else:
            raise ValueError(f"Invalid style: {self.style}")

    def get_prompt(self, messages):
        """Render a list of {"role", "content"} dicts into a single prompt string."""
        prompt = ""
        for i in range(len(messages)):
            role, content = messages[i]["role"], messages[i]["content"]
            if role == "system" and content is None:
                content = self.default_system_prompt
            if content is None:
                # No content and no default system prompt: skip this message.
                continue

            prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
            prompt += prefix + content + suffix
        return prompt
chat_template_registry: Dict[str, ChatTemplate] = {}
matching_function_registry: List[Callable] = []


def register_chat_template(template):
    """Add `template` to the global registry, keyed by its name."""
    chat_template_registry[template.name] = template


def register_chat_template_matching_function(func):
    """Register a model-path matcher used by `get_chat_template_by_model_path`."""
    matching_function_registry.append(func)


def get_chat_template(name):
    """Look up a registered template by name."""
    return chat_template_registry[name]


def get_chat_template_by_model_path(model_path):
    """Probe every registered matcher; fall back to the "default" template."""
    for matcher in matching_function_registry:
        candidate = matcher(model_path)
        if candidate is not None:
            return candidate
    return get_chat_template("default")
# Generic fallback template: plain "ROLE:" lines separated by newlines.
register_chat_template(
    ChatTemplate(
        name="default",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("SYSTEM:", "\n"),
            "user": ("USER:", "\n"),
            "assistant": ("ASSISTANT:", "\n"),
        },
    )
)

# Anthropic Claude text-completion format ("\n\nHuman:" / "\n\nAssistant:").
register_chat_template(
    ChatTemplate(
        name="claude",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("", ""),
            "user": ("\n\nHuman: ", ""),
            "assistant": ("\n\nAssistant:", ""),
        },
    )
)

# ChatML format (<|im_start|> / <|im_end|>); mapped to TinyLlama by match_chat_ml.
register_chat_template(
    ChatTemplate(
        name="chatml",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
    )
)

# Vicuna v1.1 format; also used for llava (see match_vicuna).
register_chat_template(
    ChatTemplate(
        name="vicuna_v1.1",
        default_system_prompt=(
            "A chat between a curious user and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the user's questions."
        ),
        role_prefix_and_suffix={
            "system": ("", " "),
            "user": ("USER:", " "),
            "assistant": ("ASSISTANT:", "</s>"),
        },
        image_token=" <image>\n",
    )
)

# Llama-2 chat format; uses the LLAMA2 style so the system block is wrapped
# inside the first [INST] ... [/INST] user turn.
register_chat_template(
    ChatTemplate(
        name="llama-2-chat",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
            "user": ("[INST] ", " [/INST]"),
            "assistant": ("", " </s><s>"),
        },
        style=ChatTemplateStyle.LLAMA2,
    )
)
@register_chat_template_matching_function
def match_vicuna(model_path: str):
    """Map vicuna and llava checkpoints to the vicuna_v1.1 template."""
    lowered = model_path.lower()
    if "vicuna" in lowered or "llava" in lowered:
        return get_chat_template("vicuna_v1.1")
@register_chat_template_matching_function
def match_llama2_chat(model_path: str):
    """Map llama-2 chat, mistral/mixtral instruct, and codellama instruct
    checkpoints to the llama-2-chat template."""
    lowered = model_path.lower()
    is_llama2_chat = "llama-2" in lowered and "chat" in lowered
    is_mistral_instruct = (
        "mistral" in lowered or "mixtral" in lowered
    ) and "instruct" in lowered
    is_codellama_instruct = "codellama" in lowered and "instruct" in lowered
    if is_llama2_chat or is_mistral_instruct or is_codellama_instruct:
        return get_chat_template("llama-2-chat")
@register_chat_template_matching_function
def match_chat_ml(model_path: str):
    """Map TinyLlama checkpoints to the chatml template."""
    lowered = model_path.lower()
    if "tinyllama" in lowered:
        return get_chat_template("chatml")
    return None
if __name__ == "__main__":
    # Smoke test: render a small conversation with the llama-2 template.
    messages = [
        {"role": "system", "content": None},  # None means default
        # {"role": "system", "content": "You are a helpful, respectful and honest assistant."},
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi!"},
        {"role": "user", "content": "What can you do?"},
        {"role": "assistant", "content": "I can chat with you."},
    ]

    template = get_chat_template("llama-2-chat")
    print(template.get_prompt(messages))

View File

@@ -0,0 +1,237 @@
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from typing import List, Union
from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
from sglang.lang.ir import (
SamplingParams,
SglArgument,
SglConstantText,
SglExpr,
SglVariable,
)
def compile_func(function, backend):
    """Trace `function` once against `backend` and wrap the resulting trace
    in a CompiledFunction for repeated execution."""
    traced = function.trace(backend=backend)
    return CompiledFunction(traced, function)
class CompiledFunction:
    """A compiled SGL function: a dependency graph over traced expressions.

    The tracer's expression chain is converted into CompGraphNode objects,
    topologically sorted, and replayed onto fresh StreamExecutors at run time.
    """

    def __init__(self, tracer, function):
        self.function = function
        # The final expression of the trace becomes the graph's sink node.
        self.last_node = CompGraphNode(tracer.last_node)
        self.expr_to_node = {}
        self.build_graph(tracer)
        self.topological_sort()

    def build_graph(self, tracer):
        """Walk backwards from the last traced expression, creating one node
        per reachable expression and wiring prev/source edges.

        Also renames stream pids to a dense 0..k-1 range so stream executors
        can be allocated per pid at run time.
        """
        self.nodes = [self.last_node]
        self.expr_to_node[tracer.last_node] = self.nodes[-1]
        rename_pid = {}
        visited = set([tracer.last_node])
        head = 0
        # BFS over the worklist; `head` is the read cursor into self.nodes.
        while head < len(self.nodes):
            cur_node = self.nodes[head]

            # add prev node (sequential predecessor in the same stream)
            prev_node = cur_node.expr.prev_node
            if prev_node is not None:
                if prev_node not in visited:
                    visited.add(prev_node)
                    self.nodes.append(CompGraphNode(prev_node))
                    self.expr_to_node[prev_node] = self.nodes[-1]
                cur_node.prev_node = self.expr_to_node[prev_node]
                self.expr_to_node[prev_node].add_next_node(cur_node)

            # add source node (cross-stream data dependency of a variable)
            if isinstance(cur_node.expr, SglVariable):
                # Prefer the tracer's canonical source for the variable name,
                # falling back to the expression's own recorded source.
                if cur_node.expr.name in tracer.variables:
                    source = tracer.variables[cur_node.expr.name].source
                else:
                    source = cur_node.expr.source
                if source not in visited:
                    visited.add(source)
                    self.nodes.append(CompGraphNode(source))
                    self.expr_to_node[source] = self.nodes[-1]
                cur_node.source_node = self.expr_to_node[source]
                self.expr_to_node[source].add_next_node(cur_node)

            head += 1

            # rename pid (dense re-numbering in first-visited order)
            if cur_node.expr.pid not in rename_pid:
                rename_pid[cur_node.expr.pid] = len(rename_pid)
            cur_node.expr.pid = rename_pid[cur_node.expr.pid]

    def topological_sort(self):
        """Reorder self.nodes so every node appears after its dependencies
        (Kahn's algorithm over the prev/source in-degrees)."""
        prevd = {}
        cand = Queue()
        for x in self.nodes:
            # In-degree: at most one prev edge plus at most one source edge.
            prevd[x] = (x.prev_node is not None) + (x.source_node is not None)
            if prevd[x] == 0:
                cand.put(x)

        new_list = []
        while cand.qsize() > 0:
            head = cand.get()
            new_list.append(head)
            for x in head.next_nodes:
                prevd[x] -= 1
                if prevd[x] == 0:
                    cand.put(x)

        self.nodes = new_list

    def print_graph(
        self,
    ):
        """Print one line per node in topological order (debugging aid)."""
        for node in self.nodes:
            print(node)

    def run_internal(
        self,
        backend,
        kwargs,
        default_sampling_para,
    ):
        """Replay the graph: one StreamExecutor per pid, nodes submitted in
        topological order. Returns the ProgramState of the final stream."""
        # Create stream executors
        stream_executor_ids = set([x.expr.pid for x in self.nodes])
        stream_executors = {}
        for x in stream_executor_ids:
            # Only the stream holding the last node receives the arguments.
            arguments = kwargs if x == self.last_node.expr.pid else {}
            stream_executors[x] = StreamExecutor(
                backend, arguments, default_sampling_para, None, False
            )

        for node in self.nodes:
            se_id = node.expr.pid
            expr = node.expr

            if isinstance(expr, SglVariable):
                # Make a copy for SglVariable
                expr = SglVariable(expr.name, expr.source)
                expr.source_stream_executor = stream_executors[
                    node.source_node.expr.pid
                ]
            elif isinstance(expr, SglArgument):
                # Substitute SglArgument
                expr = kwargs[expr.name]

            stream_executors[se_id].submit(expr)

        for stream_executor in stream_executors.values():
            stream_executor.end()

        return ProgramState(stream_executors[self.last_node.expr.pid])

    def run(
        self,
        *,
        max_new_tokens: int = 16,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        **kwargs,
    ):
        """Run the compiled function once with the given sampling parameters."""
        backend = backend or global_config.default_backend
        kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
        kwargs.update(self.function.bind_arguments)
        default_sampling_para = SamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )
        return self.run_internal(backend, kwargs, default_sampling_para)

    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 16,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        num_threads: Union[str, int] = "auto",
    ):
        """Run the compiled function over a batch of kwargs dicts, optionally
        using a thread pool (num_threads="auto" -> cpu_count)."""
        assert isinstance(batch_kwargs, (list, tuple))
        if len(batch_kwargs) == 0:
            return []
        assert isinstance(batch_kwargs[0], dict)

        backend = backend or global_config.default_backend
        default_sampling_para = SamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )
        batch_kwargs = [
            {k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
        ]

        # Extract prefix by tracing and cache it
        if len(batch_kwargs) > 1:
            pin_program(self.function, backend)

        # Run all programs
        if num_threads == "auto":
            num_threads = multiprocessing.cpu_count()
        num_threads = min(num_threads, len(batch_kwargs))

        if num_threads == 1:
            rets = []
            for arguments in batch_kwargs:
                rets.append(
                    self.run_internal(backend, arguments, default_sampling_para)
                )
        else:
            with ThreadPoolExecutor(num_threads) as executor:
                futures = []
                for arguments in batch_kwargs:
                    futures.append(
                        executor.submit(
                            self.run_internal, backend, arguments, default_sampling_para
                        )
                    )
                rets = [f.result() for f in futures]
                # Sync the last state so all work is finished before returning.
                rets[-1].sync()

        return rets
class CompGraphNode:
    """One node of the compiled dependency graph, wrapping a single SglExpr.

    `prev_node` is the sequential predecessor in the same stream,
    `source_node` is a cross-stream data dependency, and `next_nodes`
    are the downstream dependents.
    """

    def __init__(
        self, expr: SglExpr, prev_node=None, next_nodes=None, source_node=None
    ):
        self.expr = expr
        self.prev_node = prev_node
        self.source_node = source_node
        self.next_nodes = next_nodes or []

    def add_next_node(self, other):
        """Record `other` as a downstream dependent of this node."""
        self.next_nodes.append(other)

    def __repr__(self):
        pieces = [f"stream {self.expr.pid:2d}: ", f"%{self.expr.node_id} = "]
        if self.prev_node is not None:
            pieces.append(f"%{self.prev_node.expr.node_id} + ")
        pieces.append(repr(self.expr))
        return "".join(pieces)

View File

@@ -0,0 +1,697 @@
"""The interpreter that executes SGL programs"""
import asyncio
import multiprocessing
import queue
import threading
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Union
import tqdm
from sglang.global_config import global_config
from sglang.lang.ir import (
SglArgument,
SglCommitLazy,
SglConcateAndAppend,
SglConstantText,
SglExpr,
SglExprList,
SglFunction,
SglGen,
SglImage,
SglRoleBegin,
SglRoleEnd,
SglSelect,
SglVariable,
SglVarScopeBegin,
SglVarScopeEnd,
)
from sglang.utils import encode_image_base64
def run_internal(state, program, func_args, func_kwargs, sync):
    """Execute the program body and store its return value on `state`.

    The stream executor is always ended, even when the program raises, so the
    worker thread is unblocked. When `sync` is set, block until all submitted
    expressions have finished.
    """
    # Fix: the previous `except Exception as e: raise e` was a no-op that only
    # rewrote the traceback origin; letting the exception propagate naturally
    # preserves the full traceback. The `finally` still guarantees cleanup.
    try:
        state.ret_value = program.func(state, *func_args, **func_kwargs)
    finally:
        state.stream_executor.end()

    if sync:
        state.stream_executor.sync()

    if global_config.verbosity >= 2:
        print(state.text())
def run_program(
    program, backend, func_args, func_kwargs, default_sampling_para, stream, sync=False
):
    """Run a single SGL program: inline when not streaming, otherwise on a
    background thread so the caller can consume partial output."""
    assert backend is not None, "Please specify a backend"
    func_kwargs.update(program.bind_arguments)
    executor = StreamExecutor(
        backend, func_kwargs, default_sampling_para, chat_template=None, stream=stream
    )
    state = ProgramState(executor)

    if not stream:
        run_internal(state, program, func_args, func_kwargs, sync)
        return state

    worker = threading.Thread(
        target=run_internal, args=(state, program, func_args, func_kwargs, sync)
    )
    worker.start()
    return state
def run_program_batch(
    program,
    backend,
    batch_arguments,
    default_sampling_para,
    num_threads,
    progress_bar,
):
    """Run one program over a batch of argument dicts, optionally in parallel.

    Fixes over the original:
      * an empty batch returns [] instead of crashing
        (ThreadPoolExecutor(0) raises, and rets[-1] would IndexError);
      * `progress_bar` is honored in the sequential (num_threads == 1) path,
        where it was previously ignored.
    """
    if len(batch_arguments) == 0:
        return []

    # Extract prefix by tracing and cache it
    if len(batch_arguments) > 1:
        pin_program(program, backend)

    # Run all programs
    if num_threads == "auto":
        num_threads = multiprocessing.cpu_count()
    num_threads = min(num_threads, len(batch_arguments))

    pbar = tqdm.tqdm(total=len(batch_arguments)) if progress_bar else None

    if num_threads == 1:
        rets = []
        for arguments in batch_arguments:
            rets.append(
                run_program(
                    program, backend, (), arguments, default_sampling_para, False, False
                )
            )
            if pbar is not None:
                pbar.update()
    else:
        with ThreadPoolExecutor(num_threads) as executor:
            futures = []
            for arguments in batch_arguments:
                futures.append(
                    executor.submit(
                        run_program,
                        program,
                        backend,
                        (),
                        arguments,
                        default_sampling_para,
                        False,
                        False,
                    )
                )
                if pbar is not None:
                    futures[-1].add_done_callback(lambda _: pbar.update())
            rets = [f.result() for f in futures]
            # Sync the last state so all submitted work is finished.
            rets[-1].sync()

    if pbar is not None:
        pbar.close()

    return rets
def pin_program(program, backend):
    """Trace the program's shared prefix and cache it in the backend.

    Returns the backend's prefix rid when a prefix was pinned, else None.
    Does nothing when prefix sharing is disabled or a prefix is already pinned.
    """
    if not global_config.enable_prefix_sharing or program.pin_prefix_rid is not None:
        return None

    # TODO: handle multiple backends
    from sglang.lang.tracer import extract_prefix_by_tracing

    prefix = extract_prefix_by_tracing(program, backend)
    # Only prefixes longer than 64 characters are worth caching.
    if prefix and len(prefix) > 64:
        prefix_rid = backend.cache_prefix(prefix)
        program.pin_prefix_rid = prefix_rid
        return prefix_rid
    return None
def unpin_program(program, backend):
    """Release a prefix pinned by `pin_program`. Currently a no-op placeholder."""
    pass
class StreamExecutor:
    """A stream executor that executes SGL expressions in a background thread."""

    def __init__(
        self,
        backend,
        arguments,
        default_sampling_para,
        chat_template,
        stream,
        use_thread=True,
    ):
        # Unique stream id; backends use it to identify this request stream.
        self.sid = uuid.uuid4().hex
        self.backend = backend
        self.arguments: Dict[str, Any] = arguments
        self.default_sampling_para = default_sampling_para
        self.stream = stream

        # NOTE(review): unwraps endpoint-style wrappers; assumes `endpoint`
        # exposes the same backend interface — confirm with backend classes.
        if hasattr(backend, "endpoint"):
            self.backend = backend.endpoint

        self.variables = {}  # Dict[name: str -> value: str]
        self.variable_event = {}  # Dict[name: str -> event: threading.Event]
        self.meta_info = {}  # Dict[name: str -> info: str]
        self.is_finished = False

        # For completion
        self.text_ = ""  # The full text

        # For chat
        self.messages_ = []  # The messages in the OpenAI API format
        self.chat_template = chat_template or self.backend.get_chat_template()
        self.cur_role = None
        self.cur_role_begin_pos = None

        # For vision
        self.images_ = []
        self.cur_images = []

        # For fork/join
        self.fork_start_text_pos = None

        # Worker thread
        self.use_thread = use_thread
        if self.use_thread:
            self.queue = queue.Queue()
            self.worker = threading.Thread(target=self._thread_worker_func)
            self.worker.start()

        # For streaming
        if stream:
            self.stream_text_event = threading.Event()
            self.stream_var_event = {}
        else:
            self.stream_text_event = None
            self.stream_var_event = None

    def submit(self, expr: SglExpr):
        """Queue `expr` for execution (or run inline when use_thread is False).

        Events for variable-producing expressions are created here, before the
        expression runs, so readers can wait on them immediately.
        """
        if isinstance(expr, (SglGen, SglSelect, SglVarScopeBegin)):
            self.variable_event[expr.name] = threading.Event()
            if self.stream:
                self.stream_var_event[expr.name] = threading.Event()
        elif isinstance(expr, SglExprList):
            for e in expr.expr_list:
                if isinstance(e, (SglGen, SglSelect, SglVarScopeBegin)):
                    self.variable_event[e.name] = threading.Event()
                    if self.stream:
                        self.stream_var_event[e.name] = threading.Event()

        if self.use_thread:
            self.queue.put(expr)
        else:
            self._execute(expr)

    def sync(self):
        """Block until every queued expression has been executed."""
        if self.use_thread:
            self.queue.join()

    def get_var(self, name):
        """Return variable `name`, waiting for its producer to finish if needed."""
        if name in self.variable_event:
            self.variable_event[name].wait()
        return self.variables[name]

    def get_meta_info(self, name):
        """Return the generation meta info recorded for `name` (or None)."""
        if name in self.variable_event:
            self.variable_event[name].wait()
        ret = self.meta_info.get(name, None)
        return ret

    def fork(self, number: int, position_ids_offset: Optional[List[int]] = None):
        """Create `number` child executors sharing this executor's current state.

        `position_ids_offset` is accepted for API compatibility but unused here.
        """
        if number > 1:
            # Flush lazy backend operations before the state is duplicated.
            self.submit(SglCommitLazy())

        self.sync()
        number = int(number)

        exes = [
            StreamExecutor(
                self.backend,
                self.arguments,
                self.default_sampling_para,
                self.chat_template,
                self.stream,
            )
            for _ in range(number)
        ]
        for i in range(number):
            # Shallow copies: children start from the parent's current text,
            # variables, and messages, and remember where the fork began.
            exes[i].variables = dict(self.variables)
            exes[i].text_ = str(self.text_)
            exes[i].messages_ = list(self.messages_)
            exes[i].cur_role = self.cur_role
            exes[i].fork_start_text_pos = len(self.text_)

        return exes

    def text(self):
        """Return the full generated text (after syncing)."""
        self.sync()
        return self.text_

    def messages(self):
        """Return the chat messages in OpenAI format (after syncing)."""
        self.sync()
        return self.messages_

    def end(self):
        """Stop the worker thread and tell the backend the program ended."""
        if self.use_thread:
            if self.worker.is_alive():
                # None is the sentinel that makes the worker loop exit.
                self.queue.put(None)
        self.backend.end_program(self)

    def _thread_worker_func(self):
        # Worker loop: execute expressions until the None sentinel arrives.
        while True:
            expr = self.queue.get()
            if expr is None:
                self.queue.task_done()
                break
            self._execute(expr)
            self.queue.task_done()
            # Wake streaming readers after each executed expression.
            if self.stream_text_event:
                self.stream_text_event.set()

        # Final wake-up so readers observe is_finished.
        if self.stream_text_event:
            self.stream_text_event.set()

        self.is_finished = True

    def _execute(self, other):
        """Dispatch a single expression (or raw string) to its handler."""
        if isinstance(other, str):
            other = SglConstantText(other)

        assert isinstance(other, SglExpr), f"{other}"

        if isinstance(other, (SglConstantText, SglArgument)):
            self._execute_fill(other.value)
        elif isinstance(other, SglGen):
            self._execute_gen(other)
        elif isinstance(other, SglSelect):
            self._execute_select(other)
        elif isinstance(other, SglExprList):
            for x in other.expr_list:
                self._execute(x)
        elif isinstance(other, SglRoleBegin):
            self._execute_role_begin(other)
        elif isinstance(other, SglRoleEnd):
            self._execute_role_end(other)
        elif isinstance(other, SglImage):
            self._execute_image(other)
        elif isinstance(other, SglVariable):
            self._execute_variable(other)
        elif isinstance(other, SglVarScopeBegin):
            self._execute_var_scope_begin(other)
        elif isinstance(other, SglVarScopeEnd):
            self._execute_var_scope_end(other)
        elif isinstance(other, SglCommitLazy):
            self._execute_commit_lazy_operations(other)
        elif isinstance(other, SglConcateAndAppend):
            # Use KV-cache concatenation only when both the global flag and
            # the backend support it; otherwise fall back to text concat.
            if (
                global_config.enable_parallel_encoding
                and self.backend.support_concate_and_append
            ):
                self._execute_concatenate_and_append_kv_cache(other)
            else:
                self._execute_concatenate_and_append_text(other)
        else:
            raise ValueError(f"Unknown type: {type(other)}")

    def _execute_fill(self, value: str):
        """Append literal text to the running output."""
        value = str(value)
        self.text_ += value

    def _execute_image(self, expr: SglImage):
        """Record an image and insert the template's image token into the text."""
        path = expr.path
        if isinstance(path, SglArgument):
            path = path.value

        base64_data = encode_image_base64(path)

        self.images_.append((path, base64_data))
        self.cur_images.append((path, base64_data))
        self.text_ += self.chat_template.image_token

        # if global_config.eager_fill_image:
        #     self.backend.fill_image(self)

    def _execute_gen(self, expr: SglGen):
        """Run a generation call, storing the completion under expr.name."""
        sampling_params = self._resolve_sampling_params(expr.sampling_params)
        name = expr.name

        if not self.stream:
            comp, meta_info = self.backend.generate(
                self, sampling_params=sampling_params
            )
            self.text_ += comp

            self.variables[name] = comp
            self.meta_info[name] = meta_info
            self.variable_event[name].set()
        else:
            generator = self.backend.generate_stream(
                self, sampling_params=sampling_params
            )

            # NOTE(review): the event is set before variables[name] exists;
            # a racing reader could see the event without the key — confirm
            # whether this ordering is intentional.
            self.stream_var_event[name].set()
            self.variables[name] = ""
            for comp, meta_info in generator:
                self.text_ += comp
                self.variables[name] += comp
                self.stream_var_event[name].set()
                self.stream_text_event.set()

            self.meta_info[name] = meta_info
            self.variable_event[name].set()
            self.stream_var_event[name].set()

    def _execute_select(self, expr: SglSelect):
        """Pick one of expr.choices via the backend and append it to the text."""
        decision, scores = self.backend.select(self, expr.choices, expr.temperature)
        if expr.name is not None:
            name = expr.name
            self.variables[name] = decision
            self.variable_event[name].set()
        self.text_ += decision

    def _execute_variable(self, expr: SglVariable):
        """Fill in a variable produced by another stream executor."""
        src_executor = expr.source_stream_executor
        value = src_executor.get_var(expr.name)
        self._execute_fill(value)

    def _execute_role_begin(self, expr: SglRoleBegin):
        """Open a chat role: emit its prefix and remember where its text starts."""
        assert self.cur_role is None, "Nested roles are not allowed."

        if len(self.messages_) == 0 and expr.role != "system":
            # Insert the default system message
            default_system = self.chat_template.default_system_prompt
            if default_system:
                self._execute_role_begin(SglRoleBegin("system"))
                self._execute_fill(default_system)
                self._execute_role_end(SglRoleEnd("system"))

        self.cur_role = expr.role

        prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)

        self._execute_fill(prefix)
        self.cur_role_begin_pos = len(self.text_)

    def _execute_role_end(self, expr: SglRoleEnd):
        """Close the current role: emit its suffix and record the message."""
        new_text = self.text_[self.cur_role_begin_pos :].lstrip()

        _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
        self._execute_fill(suffix)

        if self.cur_images:
            # OpenAI vision API format
            last_msg = {
                "role": expr.role,
                "content": [{"type": "text", "text": new_text}],
            }
            for (image_path, image_base64_data) in self.cur_images:
                last_msg["content"].append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image_base64_data}"
                        },
                    }
                )
            self.messages_.append(last_msg)
            self.cur_images = []
        else:
            self.messages_.append({"role": expr.role, "content": new_text})

        self.cur_role = None

    def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
        # Temporarily store the start offset of the scope in the variable slot.
        self.variables[expr.name] = int(len(self.text_))

    def _execute_var_scope_end(self, expr: SglVarScopeEnd):
        # Replace the stored start offset with the text captured by the scope.
        self.variables[expr.name] = self.text_[self.variables[expr.name] :]
        self.variable_event[expr.name].set()

    def _execute_commit_lazy_operations(self, expr: SglCommitLazy):
        """Flush any lazily batched operations to the backend."""
        self.backend.commit_lazy_operations(self)

    def _execute_concatenate_and_append_text(self, expr: SglConcateAndAppend):
        """Join forked states by concatenating their post-fork text."""
        new_text = ""
        for s in expr.states:
            exe = s.stream_executor
            exe.sync()
            new_text += exe.text_[exe.fork_start_text_pos :]

        self._execute_fill(new_text)

    def _execute_concatenate_and_append_kv_cache(self, expr: SglConcateAndAppend):
        """Join forked states by asking the backend to splice their KV caches."""
        self_len = len(self.text_)

        for i, s in enumerate(expr.states):
            exe = s.stream_executor
            exe.submit(SglCommitLazy())

        for i, s in enumerate(expr.states):
            exe = s.stream_executor
            exe.sync()
            assert exe.fork_start_text_pos == self_len
            self.text_ += exe.text_[exe.fork_start_text_pos :]

        src_rids = [state.stream_executor.sid for state in expr.states]
        self.backend.concatenate_and_append(src_rids, self.sid)

    def _resolve_sampling_params(self, sampling_params):
        """Overlay non-None per-call parameters on top of the defaults.

        Returns the shared default object unchanged when nothing is overridden.
        """
        clone = None
        for item in [
            "max_new_tokens",
            "stop",
            "temperature",
            "top_p",
            "top_k",
            "frequency_penalty",
            "presence_penalty",
            "dtype",
            "regex",
        ]:
            value = getattr(sampling_params, item, None)
            if value is not None:
                if clone is None:
                    clone = self.default_sampling_para.clone()
                setattr(clone, item, value)
        return clone or self.default_sampling_para

    def __del__(self):
        self.end()
class ProgramState:
    """The state of an SGL program.

    A thin, user-facing wrapper around a StreamExecutor: provides role
    context managers, variable access, and (async) streaming iterators.

    Bug fix: the non-streaming branches of `text_iter` and `text_async_iter`
    referenced an undefined name `name`; they now use `var_name`.
    """

    def __init__(self, stream_executor: StreamExecutor):
        self.stream_executor = stream_executor

    def _role_common(self, name: str, expr: Optional[SglExpr] = None):
        """Submit a whole role turn, or return a context manager for one."""
        if expr is not None:
            self.stream_executor.submit(
                SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)])
            )
        else:

            @contextmanager
            def role_scope():
                self.stream_executor.submit(SglRoleBegin(name))
                yield
                self.stream_executor.submit(SglRoleEnd(name))

            return role_scope()

    def system(self, expr: Optional[SglExpr] = None):
        return self._role_common("system", expr)

    def user(self, expr: Optional[SglExpr] = None):
        return self._role_common("user", expr)

    def assistant(self, expr: Optional[SglExpr] = None):
        return self._role_common("assistant", expr)

    @contextmanager
    def var_scope(self, name: str):
        """Capture all text generated inside the scope into variable `name`."""
        self.stream_executor.submit(SglVarScopeBegin(name))
        yield
        self.stream_executor.submit(SglVarScopeEnd(name))

    def fork(self, number: int = 1, position_ids_offset: Optional[List[int]] = None):
        """Fork into `number` child states; returns a ProgramStateGroup."""
        stream_executors = self.stream_executor.fork(number, position_ids_offset)
        states = [ProgramState(x) for x in stream_executors]
        state_group = ProgramStateGroup(states, self)
        return state_group

    @contextmanager
    def copy(self, position_ids_offset: Optional[List[int]] = None):
        """Context manager yielding a single forked copy; joins on exit."""
        state_group = self.fork(1, position_ids_offset)
        try:
            yield state_group[0]
        finally:
            state_group.join()

    def text(self):
        return self.stream_executor.text()

    def messages(self):
        return self.stream_executor.messages()

    def sync(self):
        return self.stream_executor.sync()

    def text_iter(self, var_name=None):
        """Yield text chunks as they are produced.

        With `var_name`, yields chunks of that variable only; otherwise the
        whole text. In non-streaming mode, yields the final value once.
        """
        if self.stream_executor.stream:
            prev = 0

            if var_name is None:
                event = self.stream_executor.stream_text_event
                while True:
                    event.wait()
                    event.clear()
                    out = str(self.stream_executor.text_[prev:])
                    prev += len(out)
                    if out:
                        yield out
                    if self.stream_executor.is_finished:
                        break
            else:
                event = self.stream_executor.stream_var_event[var_name]
                while True:
                    event.wait()
                    event.clear()
                    out = str(self.stream_executor.variables[var_name][prev:])
                    prev += len(out)
                    if out:
                        yield out
                    if self.stream_executor.variable_event[var_name].is_set():
                        break
        else:
            if var_name is None:
                yield self.text()
            else:
                # Fix: was `self.get_var(name)` with `name` undefined.
                yield self.get_var(var_name)

    async def text_async_iter(self, var_name=None):
        """Async version of `text_iter`; waits on events in an executor."""
        loop = asyncio.get_running_loop()

        if self.stream_executor.stream:
            prev = 0

            if var_name is None:
                event = self.stream_executor.stream_text_event
                while True:
                    await loop.run_in_executor(None, event.wait)
                    event.clear()
                    out = str(self.stream_executor.text_[prev:])
                    prev += len(out)
                    if out:
                        yield out
                    if self.stream_executor.is_finished:
                        break
            else:
                event = self.stream_executor.stream_var_event[var_name]
                while True:
                    await loop.run_in_executor(None, event.wait)
                    event.clear()
                    out = str(self.stream_executor.variables[var_name][prev:])
                    prev += len(out)
                    if out:
                        yield out
                    if self.stream_executor.variable_event[var_name].is_set():
                        break
        else:
            if var_name is None:
                yield self.text()
            else:
                # Fix: was `self.get_var(name)` with `name` undefined.
                yield self.get_var(var_name)

    def get_var(self, name):
        return self.stream_executor.get_var(name)

    def get_meta_info(self, name):
        return self.stream_executor.get_meta_info(name)

    def __iadd__(self, other):
        self.stream_executor.submit(other)
        return self

    def __getitem__(self, name):
        return self.get_var(name)

    def __del__(self):
        self.stream_executor.end()

    def __repr__(self) -> str:
        msgs = self.messages()
        ret = ""
        for msg in msgs:
            ret += msg["role"] + ":\n" + msg["content"] + "\n"
        return ret
class ProgramStateGroup:
    """A group of forked program states, joined back into the source state."""

    def __init__(
        self, states: List[ProgramState], src_state: Optional[ProgramState] = None
    ):
        self.states = states
        self.src_state = src_state

    def join(self, mode: str = "gather_variable"):
        """Join child states back into the source state and end the children.

        mode "gather_variable": variables created in children (not present in
        the source before the fork) are collected into lists on the source.
        mode "concate_and_append": the children's post-fork output is appended
        to the source (text or KV cache, decided by the executor).
        """
        if mode == "gather_variable":
            # Copy variables back
            src_vars = self.src_state.stream_executor.variables
            # Snapshot of pre-existing names; only child-created variables
            # are gathered, even as this loop adds new keys to src_vars.
            src_var_set = set(src_vars.keys())

            for child_state in self.states:
                child_state.stream_executor.sync()
                child_vars = child_state.stream_executor.variables
                new_vars = set(child_vars.keys()) - src_var_set
                for k in new_vars:
                    # A previous child may have already created the list.
                    if k in src_vars:
                        src_vars[k].append(child_vars[k])
                    else:
                        src_vars[k] = [child_vars[k]]
        elif mode == "concate_and_append":
            # Concatenate and append KV cache
            self.src_state += SglConcateAndAppend(self.states)

            # Need a sync here. Otherwise, `states` can be deleted.
            self.src_state.stream_executor.sync()
        else:
            raise ValueError(f"Invalid join mode: {mode}")

        for s in self.states:
            s.stream_executor.end()

    def __getitem__(self, i: int):
        return self.states[i]

    def __setitem__(self, i: int, value):
        assert self.states[i] == value

    def __iadd__(self, other):
        """Broadcast `other` to all child states.

        A callable is invoked with the child index; a list/tuple is matched
        element-wise; a single SglExpr is appended to every child.
        """
        if isinstance(other, Callable):
            # lambda function
            for i in range(len(self.states)):
                self.states[i] += other(i)
        elif isinstance(other, SglExpr):
            for i in range(len(self.states)):
                self.states[i] += other
        elif isinstance(other, (list, tuple)):
            for i in range(len(self.states)):
                self.states[i] += other[i]
        else:
            raise ValueError(f"Invalid value: {other}")

        return self

442
python/sglang/lang/ir.py Normal file
View File

@@ -0,0 +1,442 @@
"""The intermediate representation."""
import dataclasses
import inspect
from typing import List, Optional, Union
from sglang.global_config import global_config
@dataclasses.dataclass
class SamplingParams:
    """Sampling hyper-parameters attached to a generation call."""

    max_new_tokens: int = 16
    stop: Union[str, List[str]] = ()
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1  # -1 means disable
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0

    # for constrained generation, not included in to_xxx_kwargs
    dtype: Optional[str] = None
    regex: Optional[str] = None

    def clone(self):
        """Return an independent copy of all fields.

        Bug fix: the previous implementation rebuilt the object from the first
        seven fields only, silently dropping `dtype` and `regex`.
        """
        return dataclasses.replace(self)

    def to_openai_kwargs(self):
        # OpenAI does not support top_k, so we drop it here
        return {
            "max_tokens": self.max_new_tokens,
            "stop": self.stop or None,  # empty tuple/list -> None
            "temperature": self.temperature,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
        }

    def to_anthropic_kwargs(self):
        # Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
        return {
            "max_tokens_to_sample": self.max_new_tokens,
            "stop_sequences": self.stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        }

    def to_srt_kwargs(self):
        return {
            "max_new_tokens": self.max_new_tokens,
            "stop": self.stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "regex": self.regex,
        }
class SglFunction:
    """A runnable SGL program wrapping a user function whose first argument
    is the program state `s`."""

    def __init__(self, func, bind_arguments=None):
        self.func = func
        self.bind_arguments = bind_arguments or {}
        # Set by pin(); backend rid of the cached shared prefix, if any.
        self.pin_prefix_rid = None

        # Parse arguments
        argspec = inspect.getfullargspec(func)
        assert argspec.args[0] == "s", 'The first argument must be "s"'
        self.arg_names = argspec.args[1:]

    def bind(self, **kwargs):
        """Return a new SglFunction with some arguments pre-bound."""
        assert all(key in self.arg_names for key in kwargs)
        new_bind_dict = {**self.bind_arguments, **kwargs}
        return SglFunction(self.func, bind_arguments=new_bind_dict)

    def run(
        self,
        *args,
        max_new_tokens: int = 16,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        stream: bool = False,
        backend=None,
        **kwargs,
    ):
        """Execute the program once and return its ProgramState."""
        from sglang.lang.interpreter import run_program

        default_sampling_para = SamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )
        backend = backend or global_config.default_backend
        # Wrap keyword values so the interpreter can substitute them lazily.
        kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
        return run_program(self, backend, args, kwargs, default_sampling_para, stream)

    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 16,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        num_threads: Union[str, int] = "auto",
        progress_bar: bool = False,
    ):
        """Execute the program over a batch of kwargs dicts."""
        from sglang.lang.interpreter import run_program_batch

        assert isinstance(batch_kwargs, (list, tuple))
        if len(batch_kwargs) == 0:
            return []
        assert isinstance(batch_kwargs[0], dict)

        default_sampling_para = SamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )
        backend = backend or global_config.default_backend
        batch_kwargs = [
            {k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
        ]
        return run_program_batch(
            self,
            backend,
            batch_kwargs,
            default_sampling_para,
            num_threads,
            progress_bar,
        )

    def trace(self, *, backend=None, **kwargs):
        """Trace the program symbolically without executing the backend calls."""
        from sglang.lang.tracer import trace_program

        backend = backend or global_config.default_backend
        return trace_program(self, kwargs, backend)

    def pin(self, backend=None):
        """Cache the program's shared prefix in the backend (prefix sharing)."""
        from sglang.lang.interpreter import pin_program

        backend = backend or global_config.default_backend
        return pin_program(self, backend)

    def unpin(self, backend=None):
        """Release a pinned prefix (currently a no-op in the interpreter)."""
        from sglang.lang.interpreter import unpin_program

        backend = backend or global_config.default_backend
        return unpin_program(self, backend)

    def compile(self, *, backend=None):
        """Compile the program into a dependency graph (see lang.compiler)."""
        from sglang.lang.compiler import compile_func

        return compile_func(self, backend)

    def __call__(self, *args, **kwargs):
        # Inside a tracing scope, calling a function traces it instead of
        # running it, using the scope's backend.
        from sglang.lang.tracer import TracingScope

        tracing_scope = TracingScope.get_current_scope()
        if tracing_scope is None:
            return self.run(*args, **kwargs)
        else:
            kwargs["backend"] = tracing_scope.tracer_state.backend
            return self.trace(*args, **kwargs)
class SglExpr:
    """Base class of all SGL IR expressions."""

    # Global counter used to assign a unique node_id to every expression.
    node_ct = 0

    def __init__(self):
        self.node_id = SglExpr.node_ct
        # Sequential predecessor within the same stream (set by the tracer).
        self.prev_node = None
        # Stream (program) id this expression belongs to.
        self.pid = None
        SglExpr.node_ct += 1

    def __add__(self, other):
        if isinstance(other, str):
            other = SglConstantText(other)
        assert isinstance(other, SglExpr)
        return self.concatenate_ir(self, other)

    def __radd__(self, other):
        if isinstance(other, str):
            other = SglConstantText(other)
        assert isinstance(other, SglExpr), f"{other}"
        return self.concatenate_ir(other, self)

    def concatenate_ir(self, a, b):
        """Concatenate two expressions into an SglExprList, flattening any
        SglExprList operands so lists never nest."""
        if isinstance(a, SglExprList):
            if isinstance(b, SglExprList):
                return SglExprList(a.expr_list + b.expr_list)
            else:
                return SglExprList(a.expr_list + [b])
        elif isinstance(b, SglExprList):
            return SglExprList([a] + b.expr_list)
        return SglExprList([a, b])

    def print_graph_dfs(self):
        """Render the dependency graph rooted at this expression as text,
        one "%id = ..." line per node, dependencies first (debugging aid)."""
        # Single-element list so the nested function can mutate the string.
        ret = [""]
        visited = set()

        def dfs_print(x):
            if x is None or x in visited:
                return
            visited.add(x)

            # Print dependency
            if x.prev_node is not None:
                dfs_print(x.prev_node)

            if isinstance(x, SglExprList):
                for y in x.expr_list:
                    dfs_print(y)
            # elif isinstance(x, SglRole):
            #     dfs_print(x.expr)
            elif isinstance(x, SglVariable):
                dfs_print(x.source)

            # Print the node itself
            if isinstance(x, (SglFork, SglGetForkItem)):
                ret[0] += f"%{x.node_id} = {x}\n"
            else:
                if x.prev_node is not None:
                    ret[0] += (
                        f"%{x.node_id} = %{x.prev_node.node_id} + " + str(x) + "\n"
                    )
                else:
                    ret[0] += f"%{x.node_id} = " + str(x) + "\n"

        dfs_print(self)
        return ret[0]
class SglExprList(SglExpr):
    """An IR node holding a flat sequence of expressions executed in order."""

    def __init__(self, expr_list: List[SglExpr]):
        super().__init__()
        # Stored as-is (no defensive copy).
        self.expr_list = expr_list

    def __repr__(self):
        return "ExprList({})".format(self.expr_list)
class SglArgument(SglExpr):
    """An IR placeholder wrapping one program argument.

    The wrapped ``value`` may be any Python object (the tracer also creates
    dummies with ``value=None``), so the container/number protocols below
    delegate to it. ``__format__`` deliberately raises: putting an argument
    into an f-string would bake its value into the prompt and defeat tracing.
    """

    def __init__(self, name: str, value):
        super().__init__()
        self.name = name
        self.value = value

    def __repr__(self):
        return f"Argument(name={self.name}, value={repr(self.value)})"

    def __len__(self):
        return len(self.value)

    def __getitem__(self, i):
        return self.value[i]

    def __int__(self):
        # Bug fix: __int__ must return an actual int. Returning self.value
        # raw raised TypeError whenever the wrapped value was not an int.
        return int(self.value)

    def __bool__(self):
        # Bug fix: __bool__ must return an actual bool; coerce the value so
        # `if arg:` works for non-bool wrapped values (lists, strings, ...).
        return bool(self.value)

    def __format__(self, *args):
        raise TypeError(
            "Cannot put argument inside a f-string. "
            "This is not compatible with the tracer. "
        )
class SglImage(SglExpr):
    """An IR node referencing an image by file path."""

    def __init__(self, path):
        # Bug fix: initialize the SglExpr base so this node gets a node_id,
        # prev_node, and pid like every other IR node (expression
        # concatenation and print_graph_dfs read those attributes).
        super().__init__()
        self.path = path

    def __repr__(self) -> str:
        return f"SglImage({self.path})"
class SglGen(SglExpr):
    """An IR node for one generation call, carrying its sampling options."""

    def __init__(
        self,
        name,
        max_new_tokens,
        stop,
        temperature,
        top_p,
        top_k,
        frequency_penalty,
        presence_penalty,
        dtype,
        regex,
    ):
        super().__init__()
        self.name = name
        # Bundle all decoding knobs into a single SamplingParams record.
        params = dict(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            dtype=dtype,
            regex=regex,
        )
        self.sampling_params = SamplingParams(**params)

    def __repr__(self):
        return "Gen('{}')".format(self.name)
class SglConstantText(SglExpr):
    """An IR node for a literal piece of text appended to the prompt."""

    def __init__(self, value):
        super().__init__()
        self.value = value

    def __repr__(self):
        return "Constant({!r})".format(self.value)
class SglRoleBegin(SglExpr):
    """An IR marker opening a chat role section (e.g. user/assistant)."""

    def __init__(self, role):
        super().__init__()
        self.role = role

    def __repr__(self):
        return "RoleBegin({})".format(self.role)
class SglRoleEnd(SglExpr):
    """An IR marker closing a chat role section."""

    def __init__(self, role):
        super().__init__()
        self.role = role

    def __repr__(self):
        return "RoleEnd({})".format(self.role)
class SglSelect(SglExpr):
    """An IR node that picks one of several choices at the given temperature."""

    def __init__(self, name, choices, temperature):
        super().__init__()
        self.name = name
        self.choices = choices
        self.temperature = temperature

    def __repr__(self):
        return "Select({}, choices={})".format(self.name, self.choices)
class SglFork(SglExpr):
    """An IR node that forks the program state into ``number`` branches."""

    def __init__(self, number, position_ids_offset=None):
        super().__init__()
        self.number = number
        self.position_ids_offset = position_ids_offset

    def __repr__(self):
        # NOTE: repr assumes prev_node has been linked by the tracer.
        return "Fork(%{}, number={}, position_ids_offset={})".format(
            self.prev_node.node_id, self.number, self.position_ids_offset
        )
class SglGetForkItem(SglExpr):
    """An IR node selecting the ``index``-th branch produced by a fork."""

    def __init__(self, index):
        super().__init__()
        self.index = index

    def __repr__(self):
        # NOTE: repr assumes prev_node (the fork) has been linked.
        return "GetForkItem(%{}, index={})".format(self.prev_node.node_id, self.index)
class SglVariable(SglExpr):
    """An IR reference to a named value produced by another node (``source``)."""

    def __init__(self, name, source):
        super().__init__()
        self.name = name
        self.source = source

    def __repr__(self):
        return "Variable('{}', source=%{})".format(self.name, self.source.node_id)
class SglVarScopeBegin(SglExpr):
    """An IR marker opening a named variable-capture scope."""

    def __init__(self, name):
        super().__init__()
        self.name = name

    def __repr__(self):
        return "VarScopeBegin('{}')".format(self.name)
class SglVarScopeEnd(SglExpr):
    """An IR marker closing a named variable-capture scope."""

    def __init__(self, name):
        super().__init__()
        self.name = name

    def __repr__(self):
        return "VarScopeEnd('{}')".format(self.name)
class SglConcateAndAppend(SglExpr):
    """An IR node that concatenates the given states and appends the result."""

    def __init__(self, states):
        super().__init__()
        self.states = states

    def __repr__(self):
        return "ConcatenateAndAppend('{}')".format(self.states)
class SglCommitLazy(SglExpr):
    """An IR marker that commits previously buffered lazy operations."""

    def __init__(self):
        super().__init__()

    def __repr__(self):
        # Plain string literal: the original used an f-string with no
        # placeholders (ruff F541).
        return "CommitLazy()"

# ---------------------------------------------------------------------------
# File boundary (extraction residue from the original commit view):
# the following content is sglang/lang/tracer.py (279 lines in the commit).
# ---------------------------------------------------------------------------
"""Tracing a program."""
import uuid
from typing import Any, Callable, Dict, List, Optional, Union
from sglang.backend.base_backend import BaseBackend
from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
from sglang.lang.ir import (
SglArgument,
SglCommitLazy,
SglConcateAndAppend,
SglConstantText,
SglExpr,
SglExprList,
SglFork,
SglFunction,
SglGen,
SglGetForkItem,
SglRoleBegin,
SglRoleEnd,
SglSelect,
SglVariable,
SglVarScopeBegin,
SglVarScopeEnd,
)
class StopTracing(Exception):
    """Internal control-flow exception used to abort tracing early
    (raised when only the constant prefix of a program is needed)."""

    pass
def extract_prefix_by_tracing(program, backend):
    """Trace ``program`` with dummy arguments and return its constant text prefix.

    Tracing stops (via StopTracing) as soon as a non-constant expression is
    reached; the leading SglConstantText nodes are concatenated and returned.
    """
    # Dummy arguments stand in for real values; bound arguments override them.
    arguments = {name: SglArgument(name, None) for name in program.arg_names}
    arguments.update(program.bind_arguments)

    tracer = TracerProgramState(backend, arguments, only_trace_prefix=True)
    try:
        with TracingScope(tracer):
            tracer.ret_value = program.func(tracer, **arguments)
    except StopTracing:
        pass

    # Collect the leading run of constant-text nodes.
    parts = []
    for node in tracer.flatten_nodes():
        if not isinstance(node, SglConstantText):
            break
        parts.append(node.value)
    return "".join(parts)
def trace_program(program, arguments, backend):
    """Fully trace ``program`` with the given arguments; returns the tracer state."""
    # A dummy backend is enough for tracing when none is supplied.
    backend = backend if backend is not None else BaseBackend()

    # Fill in dummies for any argument the caller did not provide, then let
    # bound arguments take final precedence.
    missing = {
        name: SglArgument(name, None)
        for name in program.arg_names
        if name not in arguments
    }
    arguments.update(missing)
    arguments.update(program.bind_arguments)

    tracer = TracerProgramState(backend, arguments, only_trace_prefix=False)
    with TracingScope(tracer):
        tracer.ret_value = program.func(tracer, **arguments)
    return tracer
class TracerProgramState(ProgramState):
    """A ProgramState that records the IR of a program instead of running it.

    Every expression "executed" against this state is appended to
    ``self.nodes`` and linked through ``prev_node``, so the traced program
    can be inspected, compiled, or replayed later.
    """

    def __init__(self, backend, arguments, only_trace_prefix):
        self.pid = uuid.uuid4().hex
        self.backend = backend
        self.arguments: Dict[str, Any] = arguments
        self.only_trace_prefix = only_trace_prefix

        # Unwrap endpoint wrappers so we trace against the real backend.
        if hasattr(backend, "endpoint"):
            self.backend = backend.endpoint

        self.nodes = []        # traced IR nodes, in execution order
        self.last_node = None  # tail of the prev_node chain
        self.variables = {}    # name -> SglVariable created by gen/select
        self.ret_value = None

        # For completion

        # For chat
        self.messages_ = []
        self.cur_role = None
        self.chat_template = self.backend.get_chat_template()

        # For multi states
        self.child_states = []

        # Register with the enclosing tracing scope (if any) so parent
        # tracers learn about states created inside their scope.
        cur_scope = TracingScope.get_current_scope()
        if cur_scope is not None:
            cur_scope.add_child_state(self)

    ##################################
    ########### Public API ###########
    ##################################

    def fork(self, number: int, position_ids_offset: Optional[List[int]] = None):
        """Fork the traced state into ``number`` children; returns a state group."""
        if self.only_trace_prefix:
            raise StopTracing()

        fork_node = SglFork(number)
        fork_node.prev_node = self.last_node

        states = [
            TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
            for _ in range(number)
        ]
        for i in range(number):
            node = SglGetForkItem(i)
            node.prev_node = fork_node
            states[i].last_node = node
            # Each child starts from a copy of the parent's chat/variable state.
            states[i].variables = dict(self.variables)
            states[i].messages_ = list(self.messages_)
            states[i].cur_role = self.cur_role
            states[i].chat_template = self.chat_template

        state_group = ProgramStateGroup(states, self)

        return state_group

    ##################################
    ########## Internal API ##########
    ##################################

    def _append_node(self, other: SglExpr):
        # Link `other` after the current tail and record it.
        self.nodes.append(other)
        other.prev_node = self.last_node
        self.last_node = other

    def _execute(self, other: SglExpr):
        """Dispatch one expression to its tracing handler."""
        if isinstance(other, str):
            other = SglConstantText(other)

        other.pid = self.pid

        if isinstance(other, SglConstantText):
            self._execute_fill(other)
        elif isinstance(other, SglGen):
            self._execute_gen(other)
        elif isinstance(other, SglSelect):
            self._execute_select(other)
        elif isinstance(other, SglExprList):
            for x in other.expr_list:
                self._execute(x)
        elif isinstance(other, SglRoleBegin):
            self._execute_role_begin(other)
        elif isinstance(other, SglRoleEnd):
            self._execute_role_end(other)
        elif isinstance(other, SglVarScopeBegin):
            self._execute_var_scope_begin(other)
        elif isinstance(other, SglVarScopeEnd):
            self._execute_var_scope_end(other)
        else:
            # Unknown node kinds end prefix-only tracing; otherwise they are
            # recorded verbatim.
            if self.only_trace_prefix:
                raise StopTracing()
            else:
                self._append_node(other)
        return self

    def __iadd__(self, other):
        self._execute(other)
        return self

    def _execute_fill(self, expr: SglConstantText):
        if isinstance(expr, str):
            expr = SglConstantText(expr)
        self._append_node(expr)

    def _execute_gen(self, expr: SglGen):
        # Auto-name anonymous generations so they stay addressable later.
        name = expr.name if expr.name is not None else "gen_" + str(len(self.variables))
        new_node = SglVariable(name, source=expr)
        self.variables[name] = new_node
        self._append_node(expr)

    def _execute_select(self, expr: SglSelect):
        name = (
            expr.name if expr.name is not None else "select_" + str(len(self.variables))
        )
        new_node = SglVariable(name, source=expr)
        self.variables[name] = new_node
        self._append_node(expr)

    def _execute_role_begin(self, expr: SglRoleBegin):
        assert self.cur_role is None, "Nested roles are not allowed."

        if len(self.messages_) == 0 and expr.role != "system":
            # Insert default system message
            default_system = self.chat_template.default_system_prompt
            if default_system:
                self._execute_role_begin(SglRoleBegin("system"))
                self._execute_fill(default_system)
                self._execute_role_end(SglRoleEnd("system"))

        self.cur_role = expr.role

        prefix, suffix = self.chat_template.get_prefix_and_suffix(
            expr.role, self.messages_
        )

        self._execute_fill(prefix)

    def _execute_role_end(self, expr: SglRoleEnd):
        prefix, suffix = self.chat_template.get_prefix_and_suffix(
            expr.role, self.messages_
        )

        self._execute_fill(suffix)

        # Content is not tracked during tracing; record an empty message so
        # the chat template still sees the correct turn structure.
        self.messages_.append({"role": expr.role, "content": ""})
        self.cur_role = None

    def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
        # NOTE(review): _execute dispatches SglVarScopeBegin here, but no
        # handler was defined, so tracing any program that opens a variable
        # scope crashed with AttributeError. Appending the marker keeps it in
        # the traced IR — confirm against the interpreter's scope semantics.
        self._append_node(expr)

    def _execute_var_scope_end(self, expr: SglVarScopeEnd):
        # Bug fix: this previously referenced an undefined bare name `name`
        # (NameError); the variable name is carried on the scope-end marker.
        new_node = SglVariable(expr.name, source=self.last_node)
        self.variables[expr.name] = new_node

    def get_var(self, name):
        """Resolve ``name`` from the arguments first, then traced variables."""
        ret = self.arguments.get(name, None)
        if ret is not None:
            return ret
        # Return a fresh SglVariable so each use is its own IR node.
        v = self.variables[name]
        return SglVariable(v.name, v.source)

    def flatten_nodes(self):
        """Return all traced nodes with SglExprList containers flattened."""

        def traverse(cur):
            if isinstance(cur, SglExprList):
                for child in cur.expr_list:
                    traverse(child)
            else:
                ret.append(cur)

        ret = []
        for x in self.nodes:
            traverse(x)
        return ret

    def __del__(self):
        pass
class TracingScope:
    """Context manager that makes a TracerProgramState the ambient scope.

    Scopes nest: entering pushes onto an implicit stack (each scope keeps a
    reference to the previous one in ``last_scope``) and exiting restores it.
    """

    # Class-level pointer to the innermost active scope.
    cur_scope = None

    def __init__(self, tracer_state: TracerProgramState):
        self.tracer_state = tracer_state
        self.last_scope = TracingScope.cur_scope

    def __enter__(self):
        TracingScope.cur_scope = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        TracingScope.cur_scope = self.last_scope

    @staticmethod
    def get_current_scope():
        return TracingScope.cur_scope

    def add_child_state(self, state: TracerProgramState):
        # Register the child state with every enclosing scope's tracer.
        # Idiom fix: compare against None with `is not` (PEP 8), not `!=`.
        cur_scope = self
        while cur_scope is not None:
            cur_scope.tracer_state.child_states.append(state)
            cur_scope = cur_scope.last_scope