release initial code
Co-authored-by: Ying Sheng <sqy1415@gmail.com> Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com> Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu> Co-authored-by: parasol-aser <3848358+parasol-aser@users.noreply.github.com> Co-authored-by: LiviaSun <33578456+ChuyueSun@users.noreply.github.com> Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
0
python/sglang/lang/__init__.py
Normal file
0
python/sglang/lang/__init__.py
Normal file
186
python/sglang/lang/chat_template.py
Normal file
186
python/sglang/lang/chat_template.py
Normal file
@@ -0,0 +1,186 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
|
||||
|
||||
class ChatTemplateStyle(Enum):
    """How role prefixes/suffixes are assembled into a final prompt."""

    PLAIN = auto()  # Each message is simply prefix + content + suffix.
    LLAMA2 = auto()  # Llama-2 style: system prompt folded into first [INST].


@dataclass
class ChatTemplate:
    """A chat prompt template: per-role wrappers plus assembly rules.

    Attributes:
        name: Registry key of the template.
        default_system_prompt: Used when a system message has ``content=None``.
        role_prefix_and_suffix: Maps role -> (prefix, suffix) strings wrapped
            around each message of that role.
        image_token: Placeholder inserted into the text for an image.
        style: Assembly rule (see ChatTemplateStyle).
    """

    name: str
    default_system_prompt: str
    # Each value is a (prefix, suffix) pair; was annotated Tuple[str], but every
    # registration in this file supplies exactly two strings.
    role_prefix_and_suffix: Dict[str, Tuple[str, str]]
    image_token: str = "<image>"
    style: ChatTemplateStyle = ChatTemplateStyle.PLAIN

    def get_prefix_and_suffix(self, role, hist_messages):
        """Return the (prefix, suffix) for *role* given the messages so far.

        For LLAMA2 style the system message is merged into the first user
        [INST] block, so the first user turn after a system message gets an
        empty prefix.
        """
        if self.style == ChatTemplateStyle.PLAIN:
            return self.role_prefix_and_suffix[role]
        elif self.style == ChatTemplateStyle.LLAMA2:
            if len(hist_messages) == 0 and role == "system":
                # Open the [INST] block before the <<SYS>> section.
                return (
                    self.role_prefix_and_suffix["user"][0]
                    + self.role_prefix_and_suffix["system"][0],
                    self.role_prefix_and_suffix["system"][1],
                )
            elif (
                len(hist_messages) == 1
                and role == "user"
                and hist_messages[0]["content"] is not None
            ):
                # [INST] was already emitted with the system message above.
                return ("", self.role_prefix_and_suffix["user"][1])
            return self.role_prefix_and_suffix[role]
        else:
            raise ValueError(f"Invalid style: {self.style}")

    def get_prompt(self, messages):
        """Render OpenAI-style ``{"role", "content"}`` messages into one string.

        A system message with ``content=None`` falls back to
        ``default_system_prompt``; messages that remain ``None`` are skipped.
        """
        prompt = ""
        for i, message in enumerate(messages):
            role, content = message["role"], message["content"]
            if role == "system" and content is None:
                content = self.default_system_prompt
            if content is None:
                continue

            prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
            prompt += prefix + content + suffix
        return prompt
|
||||
|
||||
|
||||
# Global registry of chat templates, keyed by template name.
chat_template_registry: Dict[str, ChatTemplate] = {}
# Model-path matching functions, tried in registration order by
# get_chat_template_by_model_path().
matching_function_registry: List[Callable] = []
|
||||
|
||||
|
||||
def register_chat_template(template):
    """Add *template* to the global registry under its ``name``."""
    chat_template_registry[template.name] = template
|
||||
|
||||
|
||||
def register_chat_template_matching_function(func):
    """Register *func* (model_path -> ChatTemplate or None) as a matcher.

    Usable as a decorator; matchers are consulted in registration order.
    """
    matching_function_registry.append(func)
|
||||
|
||||
|
||||
def get_chat_template(name):
    """Look up a registered template by name (raises KeyError if unknown)."""
    return chat_template_registry[name]
|
||||
|
||||
|
||||
def get_chat_template_by_model_path(model_path):
    """Resolve a chat template for *model_path*.

    Each registered matching function is tried in order; the first non-None
    result wins. Falls back to the "default" template when nothing matches.
    """
    matched = next(
        (
            template
            for template in (
                matcher(model_path) for matcher in matching_function_registry
            )
            if template is not None
        ),
        None,
    )
    if matched is not None:
        return matched
    return get_chat_template("default")
|
||||
|
||||
|
||||
# --- Built-in templates -----------------------------------------------------

# Generic fallback: simple "ROLE: content" lines.
register_chat_template(
    ChatTemplate(
        name="default",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("SYSTEM:", "\n"),
            "user": ("USER:", "\n"),
            "assistant": ("ASSISTANT:", "\n"),
        },
    )
)

# Anthropic Claude's Human/Assistant text format.
register_chat_template(
    ChatTemplate(
        name="claude",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("", ""),
            "user": ("\n\nHuman: ", ""),
            "assistant": ("\n\nAssistant:", ""),
        },
    )
)

# ChatML (<|im_start|>/<|im_end|>) as used by e.g. OpenAI-tuned / Qwen models.
register_chat_template(
    ChatTemplate(
        name="chatml",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
    )
)

# Vicuna v1.1; also used for llava (note the custom image token).
register_chat_template(
    ChatTemplate(
        name="vicuna_v1.1",
        default_system_prompt=(
            "A chat between a curious user and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the user's questions."
        ),
        role_prefix_and_suffix={
            "system": ("", " "),
            "user": ("USER:", " "),
            "assistant": ("ASSISTANT:", "</s>"),
        },
        image_token=" <image>\n",
    )
)

# Llama-2 chat; uses the LLAMA2 style so the system prompt is folded into
# the first [INST] block.
register_chat_template(
    ChatTemplate(
        name="llama-2-chat",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
            "user": ("[INST] ", " [/INST]"),
            "assistant": ("", " </s><s>"),
        },
        style=ChatTemplateStyle.LLAMA2,
    )
)
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_vicuna(model_path: str):
    """Map vicuna and llava model paths to the vicuna_v1.1 template."""
    lowered = model_path.lower()
    if "vicuna" in lowered or "llava" in lowered:
        return get_chat_template("vicuna_v1.1")
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_llama2_chat(model_path: str):
    """Map llama-2 chat, mistral/mixtral instruct, and codellama instruct
    model paths to the llama-2-chat template."""
    lowered = model_path.lower()
    is_llama2_chat = "llama-2" in lowered and "chat" in lowered
    is_mistral_instruct = (
        "mistral" in lowered or "mixtral" in lowered
    ) and "instruct" in lowered
    is_codellama_instruct = "codellama" in lowered and "instruct" in lowered
    if is_llama2_chat or is_mistral_instruct or is_codellama_instruct:
        return get_chat_template("llama-2-chat")
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_chat_ml(model_path: str):
    """Map tinyllama model paths to the chatml template."""
    if "tinyllama" in model_path.lower():
        return get_chat_template("chatml")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: render a small conversation with the llama-2
    # template and print the assembled prompt.
    messages = [
        {"role": "system", "content": None},  # None means default
        # {"role": "system", "content": "You are a helpful, respectful and honest assistant."},
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi!"},
        {"role": "user", "content": "What can you do?"},
        {"role": "assistant", "content": "I can chat with you."},
    ]

    template = get_chat_template("llama-2-chat")
    print(template.get_prompt(messages))
|
||||
237
python/sglang/lang/compiler.py
Normal file
237
python/sglang/lang/compiler.py
Normal file
@@ -0,0 +1,237 @@
|
||||
import multiprocessing
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from queue import Queue
|
||||
from typing import List, Union
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
|
||||
from sglang.lang.ir import (
|
||||
SamplingParams,
|
||||
SglArgument,
|
||||
SglConstantText,
|
||||
SglExpr,
|
||||
SglVariable,
|
||||
)
|
||||
|
||||
|
||||
def compile_func(function, backend):
    """Trace *function* on *backend* and build its compiled dataflow graph."""
    traced = function.trace(backend=backend)
    return CompiledFunction(traced, function)
|
||||
|
||||
|
||||
class CompiledFunction:
    """An SGL function compiled into a dataflow graph of expressions.

    Built from a tracer: each traced SglExpr becomes a CompGraphNode with
    `prev_node` (sequential predecessor) and `source_node` (data dependency
    for SglVariable) edges. Nodes are topologically sorted so execution can
    be replayed onto fresh StreamExecutors.
    """

    def __init__(self, tracer, function):
        self.function = function

        # Graph is discovered backwards from the tracer's last node.
        self.last_node = CompGraphNode(tracer.last_node)
        self.expr_to_node = {}  # maps traced SglExpr -> CompGraphNode
        self.build_graph(tracer)
        self.topological_sort()

    def build_graph(self, tracer):
        """BFS backwards from the last node, wiring prev/source edges and
        renaming stream pids to a dense 0..k range."""
        self.nodes = [self.last_node]
        self.expr_to_node[tracer.last_node] = self.nodes[-1]

        rename_pid = {}

        visited = set([tracer.last_node])
        head = 0
        while head < len(self.nodes):
            cur_node = self.nodes[head]

            # add prev node
            prev_node = cur_node.expr.prev_node
            if prev_node is not None:
                if prev_node not in visited:
                    visited.add(prev_node)
                    self.nodes.append(CompGraphNode(prev_node))
                    self.expr_to_node[prev_node] = self.nodes[-1]
                cur_node.prev_node = self.expr_to_node[prev_node]
                self.expr_to_node[prev_node].add_next_node(cur_node)

            # add source node (cross-stream data dependency of a variable)
            if isinstance(cur_node.expr, SglVariable):
                if cur_node.expr.name in tracer.variables:
                    source = tracer.variables[cur_node.expr.name].source
                else:
                    source = cur_node.expr.source
                if source not in visited:
                    visited.add(source)
                    self.nodes.append(CompGraphNode(source))
                    self.expr_to_node[source] = self.nodes[-1]
                cur_node.source_node = self.expr_to_node[source]
                self.expr_to_node[source].add_next_node(cur_node)
            head += 1

            # rename pid (compact stream ids in discovery order)
            if cur_node.expr.pid not in rename_pid:
                rename_pid[cur_node.expr.pid] = len(rename_pid)
            cur_node.expr.pid = rename_pid[cur_node.expr.pid]

    def topological_sort(self):
        """Kahn's algorithm: reorder self.nodes so every node appears after
        its prev/source dependencies."""
        prevd = {}  # node -> number of unsatisfied dependencies
        cand = Queue()
        for x in self.nodes:
            prevd[x] = (x.prev_node is not None) + (x.source_node is not None)
            if prevd[x] == 0:
                cand.put(x)
        new_list = []
        while cand.qsize() > 0:
            head = cand.get()
            new_list.append(head)
            for x in head.next_nodes:
                prevd[x] -= 1
                if prevd[x] == 0:
                    cand.put(x)
        self.nodes = new_list

    def print_graph(
        self,
    ):
        """Debug helper: print each graph node in topological order."""
        for node in self.nodes:
            print(node)

    def run_internal(
        self,
        backend,
        kwargs,
        default_sampling_para,
    ):
        """Replay the graph: one StreamExecutor per stream pid, submitting
        each expression to its stream in topological order."""
        stream_executor_ids = set([x.expr.pid for x in self.nodes])
        stream_executors = {}
        for x in stream_executor_ids:
            # Only the main stream (the one producing the final node)
            # receives the user's arguments.
            arguments = kwargs if x == self.last_node.expr.pid else {}
            stream_executors[x] = StreamExecutor(
                backend, arguments, default_sampling_para, None, False
            )
        for node in self.nodes:
            se_id = node.expr.pid
            expr = node.expr
            if isinstance(expr, SglVariable):
                # Make a copy for SglVariable
                expr = SglVariable(expr.name, expr.source)
                expr.source_stream_executor = stream_executors[
                    node.source_node.expr.pid
                ]
            elif isinstance(expr, SglArgument):
                # Substitute SglArgument
                expr = kwargs[expr.name]
            stream_executors[se_id].submit(expr)
        for stream_executor in stream_executors.values():
            stream_executor.end()
        return ProgramState(stream_executors[self.last_node.expr.pid])

    def run(
        self,
        *,
        max_new_tokens: int = 16,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        **kwargs,
    ):
        """Run the compiled function once and return its ProgramState.

        Keyword-only sampling parameters mirror SamplingParams; remaining
        **kwargs are the SGL function's own arguments.
        """
        backend = backend or global_config.default_backend

        kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
        kwargs.update(self.function.bind_arguments)

        default_sampling_para = SamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )

        return self.run_internal(backend, kwargs, default_sampling_para)

    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 16,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        num_threads: Union[str, int] = "auto",
    ):
        """Run the compiled function for each dict in *batch_kwargs*.

        Uses a thread pool ("auto" = cpu count, capped at batch size) and
        pins the shared prompt prefix on the backend when the batch has
        more than one item. Returns a list of ProgramState.
        """
        assert isinstance(batch_kwargs, (list, tuple))
        if len(batch_kwargs) == 0:
            return []
        assert isinstance(batch_kwargs[0], dict)

        backend = backend or global_config.default_backend

        default_sampling_para = SamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )
        batch_kwargs = [
            {k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
        ]

        # Extract prefix by tracing and cache it
        if len(batch_kwargs) > 1:
            pin_program(self.function, backend)

        # Run all programs
        if num_threads == "auto":
            num_threads = multiprocessing.cpu_count()
        num_threads = min(num_threads, len(batch_kwargs))

        if num_threads == 1:
            rets = []
            for arguments in batch_kwargs:
                rets.append(
                    self.run_internal(backend, arguments, default_sampling_para)
                )
        else:
            with ThreadPoolExecutor(num_threads) as executor:
                futures = []
                for arguments in batch_kwargs:
                    futures.append(
                        executor.submit(
                            self.run_internal, backend, arguments, default_sampling_para
                        )
                    )
                rets = [f.result() for f in futures]
            # Wait for the last program so all work is flushed before return.
            rets[-1].sync()

        return rets
|
||||
|
||||
|
||||
class CompGraphNode:
    """A node in the compiled dataflow graph, wrapping one SGL expression.

    Tracks the sequential predecessor (`prev_node`), downstream dependents
    (`next_nodes`), and the cross-stream data dependency (`source_node`).
    """

    def __init__(
        self, expr: SglExpr, prev_node=None, next_nodes=None, source_node=None
    ):
        self.expr = expr
        self.next_nodes = next_nodes if next_nodes else []
        self.prev_node = prev_node
        self.source_node = source_node

    def add_next_node(self, other):
        """Register *other* as a downstream dependent of this node."""
        self.next_nodes.append(other)

    def __repr__(self):
        parts = [f"stream {self.expr.pid:2d}: ", f"%{self.expr.node_id} = "]
        if self.prev_node is not None:
            parts.append(f"%{self.prev_node.expr.node_id} + ")
        parts.append(repr(self.expr))
        return "".join(parts)
|
||||
697
python/sglang/lang/interpreter.py
Normal file
697
python/sglang/lang/interpreter.py
Normal file
@@ -0,0 +1,697 @@
|
||||
"""The interpreter that executes SGL programs"""
|
||||
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import queue
|
||||
import threading
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import tqdm
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.ir import (
|
||||
SglArgument,
|
||||
SglCommitLazy,
|
||||
SglConcateAndAppend,
|
||||
SglConstantText,
|
||||
SglExpr,
|
||||
SglExprList,
|
||||
SglFunction,
|
||||
SglGen,
|
||||
SglImage,
|
||||
SglRoleBegin,
|
||||
SglRoleEnd,
|
||||
SglSelect,
|
||||
SglVariable,
|
||||
SglVarScopeBegin,
|
||||
SglVarScopeEnd,
|
||||
)
|
||||
from sglang.utils import encode_image_base64
|
||||
|
||||
|
||||
def run_internal(state, program, func_args, func_kwargs, sync):
    """Execute *program*'s Python function against *state*.

    The stream executor is always closed (even on error); when *sync* is
    True, waits for all submitted expressions to finish. The return value
    of the user function is stored in ``state.ret_value``.
    """
    # The original `except Exception as e: raise e` was a pointless re-raise
    # that also truncated the traceback; try/finally alone preserves both the
    # exception propagation and the cleanup.
    try:
        state.ret_value = program.func(state, *func_args, **func_kwargs)
    finally:
        state.stream_executor.end()

    if sync:
        state.stream_executor.sync()

    if global_config.verbosity >= 2:
        print(state.text())
|
||||
|
||||
|
||||
def run_program(
    program, backend, func_args, func_kwargs, default_sampling_para, stream, sync=False
):
    """Run one SGL program on *backend* and return its ProgramState.

    When *stream* is True the program body runs in a background thread so
    the caller can consume streamed output immediately; otherwise it runs
    in the calling thread (still asynchronous unless *sync* is True).
    """
    assert backend is not None, "Please specify a backend"
    func_kwargs.update(program.bind_arguments)
    stream_executor = StreamExecutor(
        backend, func_kwargs, default_sampling_para, chat_template=None, stream=stream
    )
    state = ProgramState(stream_executor)

    if stream:
        t = threading.Thread(
            target=run_internal, args=(state, program, func_args, func_kwargs, sync)
        )
        t.start()
        return state
    else:
        run_internal(state, program, func_args, func_kwargs, sync)
        return state
||||
|
||||
|
||||
def run_program_batch(
    program,
    backend,
    batch_arguments,
    default_sampling_para,
    num_threads,
    progress_bar,
):
    """Run *program* once per argument dict in *batch_arguments*.

    num_threads may be "auto" (cpu count, capped at batch size). Returns a
    list of ProgramState in input order.
    """
    # Extract prefix by tracing and cache it
    if len(batch_arguments) > 1:
        pin_program(program, backend)

    # Run all programs
    if num_threads == "auto":
        num_threads = multiprocessing.cpu_count()
    num_threads = min(num_threads, len(batch_arguments))

    if num_threads == 1:
        # NOTE(review): progress_bar is ignored on this single-threaded
        # path, and no final sync() is done here (unlike the threaded
        # branch) — confirm whether that is intentional.
        rets = []
        for arguments in batch_arguments:
            rets.append(
                run_program(
                    program, backend, (), arguments, default_sampling_para, False, False
                )
            )
    else:
        if progress_bar:
            pbar = tqdm.tqdm(total=len(batch_arguments))

        with ThreadPoolExecutor(num_threads) as executor:
            futures = []
            for arguments in batch_arguments:
                futures.append(
                    executor.submit(
                        run_program,
                        program,
                        backend,
                        (),
                        arguments,
                        default_sampling_para,
                        False,
                        False,
                    )
                )
                if progress_bar:
                    futures[-1].add_done_callback(lambda _: pbar.update())

            rets = [f.result() for f in futures]
        # Wait for the last program so all queued work is flushed.
        rets[-1].sync()

        if progress_bar:
            pbar.close()

    return rets
|
||||
|
||||
|
||||
def pin_program(program, backend):
    """Cache the program's shared prompt prefix on the backend.

    Only acts when prefix sharing is enabled and the program is not already
    pinned. Prefixes of 64 characters or fewer are not worth caching.
    Returns the backend's prefix request id, or None if nothing was pinned.
    """
    if global_config.enable_prefix_sharing and program.pin_prefix_rid is None:
        # TODO: handle multiple backends
        from sglang.lang.tracer import extract_prefix_by_tracing

        prefix = extract_prefix_by_tracing(program, backend)
        if prefix and len(prefix) > 64:
            prefix_rid = backend.cache_prefix(prefix)
            program.pin_prefix_rid = prefix_rid
            return prefix_rid
    return None
|
||||
|
||||
|
||||
def unpin_program(program, backend):
    """Release a prefix pinned by pin_program. Currently a no-op placeholder."""
    pass
|
||||
|
||||
|
||||
class StreamExecutor:
    """A stream executor that executes SGL expressions in a background thread.

    Expressions are submitted to a queue and consumed in order by a worker
    thread (unless use_thread=False, in which case they execute inline).
    Per-variable threading.Events let readers block until a value is ready.
    """

    def __init__(
        self,
        backend,
        arguments,
        default_sampling_para,
        chat_template,
        stream,
        use_thread=True,
    ):
        # Unique stream/session id used by the backend.
        self.sid = uuid.uuid4().hex
        self.backend = backend
        self.arguments: Dict[str, Any] = arguments
        self.default_sampling_para = default_sampling_para
        self.stream = stream

        if hasattr(backend, "endpoint"):
            self.backend = backend.endpoint

        self.variables = {}  # Dict[name: str -> value: str]
        self.variable_event = {}  # Dict[name: str -> event: threading.Event]
        self.meta_info = {}  # Dict[name: str -> info: str]
        self.is_finished = False

        # For completion
        self.text_ = ""  # The full text

        # For chat
        self.messages_ = []  # The messages in the OpenAI API format
        self.chat_template = chat_template or self.backend.get_chat_template()
        self.cur_role = None
        self.cur_role_begin_pos = None

        # For vision
        self.images_ = []
        self.cur_images = []

        # For fork/join
        self.fork_start_text_pos = None

        # Worker thread
        self.use_thread = use_thread
        if self.use_thread:
            self.queue = queue.Queue()
            self.worker = threading.Thread(target=self._thread_worker_func)
            self.worker.start()

        # For streaming
        if stream:
            self.stream_text_event = threading.Event()
            self.stream_var_event = {}
        else:
            self.stream_text_event = None
            self.stream_var_event = None

    def submit(self, expr: SglExpr):
        """Queue *expr* for execution (or execute inline without a thread).

        Events for generated/selected/scoped variables are created eagerly
        so get_var() can block on them before execution starts.
        """
        if isinstance(expr, (SglGen, SglSelect, SglVarScopeBegin)):
            self.variable_event[expr.name] = threading.Event()
            if self.stream:
                self.stream_var_event[expr.name] = threading.Event()
        elif isinstance(expr, SglExprList):
            for e in expr.expr_list:
                if isinstance(e, (SglGen, SglSelect, SglVarScopeBegin)):
                    self.variable_event[e.name] = threading.Event()
                    if self.stream:
                        self.stream_var_event[e.name] = threading.Event()

        if self.use_thread:
            self.queue.put(expr)
        else:
            self._execute(expr)

    def sync(self):
        """Block until every queued expression has been executed."""
        if self.use_thread:
            self.queue.join()

    def get_var(self, name):
        """Return a variable's value, waiting until it is produced."""
        if name in self.variable_event:
            self.variable_event[name].wait()
        return self.variables[name]

    def get_meta_info(self, name):
        """Return backend meta info for a variable (None if absent)."""
        if name in self.variable_event:
            self.variable_event[name].wait()
        ret = self.meta_info.get(name, None)
        return ret

    def fork(self, number: int, position_ids_offset: Optional[List[int]] = None):
        """Create *number* child executors that share this one's state so far.

        For a real fork (>1), lazy operations are committed and the queue is
        drained first so children copy a consistent snapshot.
        """
        if number > 1:
            self.submit(SglCommitLazy())
        self.sync()

        number = int(number)

        exes = [
            StreamExecutor(
                self.backend,
                self.arguments,
                self.default_sampling_para,
                self.chat_template,
                self.stream,
            )
            for _ in range(number)
        ]
        for i in range(number):
            exes[i].variables = dict(self.variables)
            exes[i].text_ = str(self.text_)
            exes[i].messages_ = list(self.messages_)
            exes[i].cur_role = self.cur_role
            # Children remember where their own text starts, so join can
            # splice only the newly generated suffix back.
            exes[i].fork_start_text_pos = len(self.text_)

        return exes

    def text(self):
        """Return the full completion text (waits for pending work)."""
        self.sync()
        return self.text_

    def messages(self):
        """Return the chat messages (waits for pending work)."""
        self.sync()
        return self.messages_

    def end(self):
        """Signal the worker thread to stop and notify the backend."""
        if self.use_thread:
            if self.worker.is_alive():
                self.queue.put(None)  # None is the shutdown sentinel
        self.backend.end_program(self)

    def _thread_worker_func(self):
        """Worker loop: pop expressions off the queue until the sentinel."""
        while True:
            expr = self.queue.get()
            if expr is None:
                self.queue.task_done()
                break

            self._execute(expr)
            self.queue.task_done()
            if self.stream_text_event:
                self.stream_text_event.set()

        # Final wake-up so text_iter consumers observe is_finished.
        if self.stream_text_event:
            self.stream_text_event.set()

        self.is_finished = True

    def _execute(self, other):
        """Dispatch one expression (or plain string) to its handler."""
        if isinstance(other, str):
            other = SglConstantText(other)

        assert isinstance(other, SglExpr), f"{other}"

        if isinstance(other, (SglConstantText, SglArgument)):
            self._execute_fill(other.value)
        elif isinstance(other, SglGen):
            self._execute_gen(other)
        elif isinstance(other, SglSelect):
            self._execute_select(other)
        elif isinstance(other, SglExprList):
            for x in other.expr_list:
                self._execute(x)
        elif isinstance(other, SglRoleBegin):
            self._execute_role_begin(other)
        elif isinstance(other, SglRoleEnd):
            self._execute_role_end(other)
        elif isinstance(other, SglImage):
            self._execute_image(other)
        elif isinstance(other, SglVariable):
            self._execute_variable(other)
        elif isinstance(other, SglVarScopeBegin):
            self._execute_var_scope_begin(other)
        elif isinstance(other, SglVarScopeEnd):
            self._execute_var_scope_end(other)
        elif isinstance(other, SglCommitLazy):
            self._execute_commit_lazy_operations(other)
        elif isinstance(other, SglConcateAndAppend):
            # Prefer the KV-cache fast path when both the global flag and
            # the backend support it; otherwise fall back to plain text.
            if (
                global_config.enable_parallel_encoding
                and self.backend.support_concate_and_append
            ):
                self._execute_concatenate_and_append_kv_cache(other)
            else:
                self._execute_concatenate_and_append_text(other)
        else:
            raise ValueError(f"Unknown type: {type(other)}")

    def _execute_fill(self, value: str):
        """Append literal text to the completion."""
        value = str(value)
        self.text_ += value

    def _execute_image(self, expr: SglImage):
        """Load an image, record it, and insert the template's image token."""
        path = expr.path
        if isinstance(path, SglArgument):
            path = path.value

        base64_data = encode_image_base64(path)

        self.images_.append((path, base64_data))
        self.cur_images.append((path, base64_data))
        self.text_ += self.chat_template.image_token

        # if global_config.eager_fill_image:
        #     self.backend.fill_image(self)

    def _execute_gen(self, expr: SglGen):
        """Run a generation call, storing the output under expr.name."""
        sampling_params = self._resolve_sampling_params(expr.sampling_params)
        name = expr.name

        if not self.stream:
            comp, meta_info = self.backend.generate(
                self, sampling_params=sampling_params
            )
            self.text_ += comp

            self.variables[name] = comp
            self.meta_info[name] = meta_info
            self.variable_event[name].set()
        else:
            generator = self.backend.generate_stream(
                self, sampling_params=sampling_params
            )

            self.stream_var_event[name].set()

            self.variables[name] = ""
            for comp, meta_info in generator:
                self.text_ += comp
                self.variables[name] += comp
                self.stream_var_event[name].set()
                self.stream_text_event.set()

            self.meta_info[name] = meta_info

            self.variable_event[name].set()
            self.stream_var_event[name].set()

    def _execute_select(self, expr: SglSelect):
        """Ask the backend to pick one of expr.choices and append it."""
        decision, scores = self.backend.select(self, expr.choices, expr.temperature)
        if expr.name is not None:
            name = expr.name
            self.variables[name] = decision
            self.variable_event[name].set()
        self.text_ += decision

    def _execute_variable(self, expr: SglVariable):
        """Fill in a variable produced by another stream executor."""
        src_executor = expr.source_stream_executor
        value = src_executor.get_var(expr.name)
        self._execute_fill(value)

    def _execute_role_begin(self, expr: SglRoleBegin):
        """Open a chat role: emit its prefix (auto-inserting the template's
        default system message before the first non-system role)."""
        assert self.cur_role is None, "Nested roles are not allowed."

        if len(self.messages_) == 0 and expr.role != "system":
            # Insert the default system message
            default_system = self.chat_template.default_system_prompt
            if default_system:
                self._execute_role_begin(SglRoleBegin("system"))
                self._execute_fill(default_system)
                self._execute_role_end(SglRoleEnd("system"))

        self.cur_role = expr.role

        prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)

        self._execute_fill(prefix)
        self.cur_role_begin_pos = len(self.text_)

    def _execute_role_end(self, expr: SglRoleEnd):
        """Close a chat role: emit its suffix and record the message."""
        new_text = self.text_[self.cur_role_begin_pos :].lstrip()

        _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
        self._execute_fill(suffix)

        if self.cur_images:
            # OpenAI vision API format
            last_msg = {
                "role": expr.role,
                "content": [{"type": "text", "text": new_text}],
            }
            for (image_path, image_base64_data) in self.cur_images:
                last_msg["content"].append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image_base64_data}"
                        },
                    }
                )
            self.messages_.append(last_msg)
            self.cur_images = []
        else:
            self.messages_.append({"role": expr.role, "content": new_text})

        self.cur_role = None

    def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
        # Temporarily store the start offset; replaced by the captured text
        # in _execute_var_scope_end.
        self.variables[expr.name] = int(len(self.text_))

    def _execute_var_scope_end(self, expr: SglVarScopeEnd):
        self.variables[expr.name] = self.text_[self.variables[expr.name] :]
        self.variable_event[expr.name].set()

    def _execute_commit_lazy_operations(self, expr: SglCommitLazy):
        """Flush any backend-side lazy operations."""
        self.backend.commit_lazy_operations(self)

    def _execute_concatenate_and_append_text(self, expr: SglConcateAndAppend):
        """Join forked states by concatenating their generated text suffixes."""
        new_text = ""
        for s in expr.states:
            exe = s.stream_executor
            exe.sync()
            new_text += exe.text_[exe.fork_start_text_pos :]

        self._execute_fill(new_text)

    def _execute_concatenate_and_append_kv_cache(self, expr: SglConcateAndAppend):
        """Join forked states via the backend's KV-cache concatenation."""
        self_len = len(self.text_)

        for i, s in enumerate(expr.states):
            exe = s.stream_executor
            exe.submit(SglCommitLazy())

        for i, s in enumerate(expr.states):
            exe = s.stream_executor
            exe.sync()
            assert exe.fork_start_text_pos == self_len
            self.text_ += exe.text_[exe.fork_start_text_pos :]

        src_rids = [state.stream_executor.sid for state in expr.states]
        self.backend.concatenate_and_append(src_rids, self.sid)

    def _resolve_sampling_params(self, sampling_params):
        """Overlay non-None per-call params on the executor defaults.

        Clones the default params lazily — only when an override exists.
        """
        clone = None
        for item in [
            "max_new_tokens",
            "stop",
            "temperature",
            "top_p",
            "top_k",
            "frequency_penalty",
            "presence_penalty",
            "dtype",
            "regex",
        ]:
            value = getattr(sampling_params, item, None)
            if value is not None:
                if clone is None:
                    clone = self.default_sampling_para.clone()
                setattr(clone, item, value)
        return clone or self.default_sampling_para

    def __del__(self):
        self.end()
|
||||
|
||||
|
||||
class ProgramState:
|
||||
"""The state of an SGL program."""
|
||||
|
||||
def __init__(self, stream_executor: StreamExecutor):
|
||||
self.stream_executor = stream_executor
|
||||
|
||||
def _role_common(self, name: str, expr: Optional[SglExpr] = None):
|
||||
if expr is not None:
|
||||
self.stream_executor.submit(
|
||||
SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)])
|
||||
)
|
||||
else:
|
||||
|
||||
@contextmanager
|
||||
def role_scope():
|
||||
self.stream_executor.submit(SglRoleBegin(name))
|
||||
yield
|
||||
self.stream_executor.submit(SglRoleEnd(name))
|
||||
|
||||
return role_scope()
|
||||
|
||||
def system(self, expr: Optional[SglExpr] = None):
|
||||
return self._role_common("system", expr)
|
||||
|
||||
def user(self, expr: Optional[SglExpr] = None):
|
||||
return self._role_common("user", expr)
|
||||
|
||||
def assistant(self, expr: Optional[SglExpr] = None):
|
||||
return self._role_common("assistant", expr)
|
||||
|
||||
@contextmanager
|
||||
def var_scope(self, name: str):
|
||||
self.stream_executor.submit(SglVarScopeBegin(name))
|
||||
yield
|
||||
self.stream_executor.submit(SglVarScopeEnd(name))
|
||||
|
||||
def fork(self, number: int = 1, position_ids_offset: Optional[List[int]] = None):
|
||||
stream_executors = self.stream_executor.fork(number, position_ids_offset)
|
||||
states = [ProgramState(x) for x in stream_executors]
|
||||
state_group = ProgramStateGroup(states, self)
|
||||
return state_group
|
||||
|
||||
@contextmanager
|
||||
def copy(self, position_ids_offset: Optional[List[int]] = None):
|
||||
state_group = self.fork(1, position_ids_offset)
|
||||
try:
|
||||
yield state_group[0]
|
||||
finally:
|
||||
state_group.join()
|
||||
|
||||
def text(self):
|
||||
return self.stream_executor.text()
|
||||
|
||||
def messages(self):
|
||||
return self.stream_executor.messages()
|
||||
|
||||
def sync(self):
|
||||
return self.stream_executor.sync()
|
||||
|
||||
def text_iter(self, var_name=None):
|
||||
if self.stream_executor.stream:
|
||||
prev = 0
|
||||
if var_name is None:
|
||||
event = self.stream_executor.stream_text_event
|
||||
while True:
|
||||
event.wait()
|
||||
event.clear()
|
||||
out = str(self.stream_executor.text_[prev:])
|
||||
prev += len(out)
|
||||
if out:
|
||||
yield out
|
||||
if self.stream_executor.is_finished:
|
||||
break
|
||||
else:
|
||||
event = self.stream_executor.stream_var_event[var_name]
|
||||
while True:
|
||||
event.wait()
|
||||
event.clear()
|
||||
out = str(self.stream_executor.variables[var_name][prev:])
|
||||
prev += len(out)
|
||||
if out:
|
||||
yield out
|
||||
if self.stream_executor.variable_event[var_name].is_set():
|
||||
break
|
||||
else:
|
||||
if var_name is None:
|
||||
yield self.text()
|
||||
else:
|
||||
yield self.get_var(name)
|
||||
|
||||
async def text_async_iter(self, var_name=None):
    """Async variant of `text_iter`.

    The blocking `event.wait()` calls are offloaded to the default thread
    pool via `loop.run_in_executor` so the event loop is never blocked.
    """
    loop = asyncio.get_running_loop()

    if self.stream_executor.stream:
        prev = 0
        if var_name is None:
            event = self.stream_executor.stream_text_event
            while True:
                await loop.run_in_executor(None, event.wait)
                event.clear()
                # Yield only the part produced since the last wakeup.
                out = str(self.stream_executor.text_[prev:])
                prev += len(out)
                if out:
                    yield out
                if self.stream_executor.is_finished:
                    break
        else:
            event = self.stream_executor.stream_var_event[var_name]
            while True:
                await loop.run_in_executor(None, event.wait)
                event.clear()
                out = str(self.stream_executor.variables[var_name][prev:])
                prev += len(out)
                if out:
                    yield out
                # The variable's completion event marks the end of this scope.
                if self.stream_executor.variable_event[var_name].is_set():
                    break
    else:
        if var_name is None:
            yield self.text()
        else:
            # Bug fix: was `self.get_var(name)` — `name` is undefined here
            # and raised NameError; the parameter is `var_name`.
            yield self.get_var(var_name)
|
||||
|
||||
def get_var(self, name):
    """Look up the value of generation variable `name`."""
    return self.stream_executor.get_var(name)
|
||||
|
||||
def get_meta_info(self, name):
    """Return backend metadata (e.g. token counts) for variable `name`."""
    return self.stream_executor.get_meta_info(name)
|
||||
|
||||
def __iadd__(self, other):
    """Enable `state += expr`: submit the expression to the executor."""
    self.stream_executor.submit(other)
    return self
|
||||
|
||||
def __getitem__(self, name):
    """Enable `state[name]` as a shorthand for `get_var(name)`."""
    return self.get_var(name)
|
||||
|
||||
def __del__(self):
    # Shut down the underlying executor when the state is garbage-collected.
    self.stream_executor.end()
|
||||
|
||||
def __repr__(self) -> str:
    """Render the recorded chat history as `role:` / content lines."""
    return "".join(
        m["role"] + ":\n" + m["content"] + "\n" for m in self.messages()
    )
|
||||
|
||||
|
||||
class ProgramStateGroup:
    """A group of forked child states, joined back into a source state."""

    def __init__(
        self, states: List[ProgramState], src_state: Optional[ProgramState] = None
    ):
        # `states`: children produced by `fork`.
        # `src_state`: the state they were forked from (may be None for tracers).
        self.states = states
        self.src_state = src_state

    def join(self, mode: str = "gather_variable"):
        """Wait for every child and merge results into the source state.

        mode == "gather_variable": variables created by the children (and not
        present in the source before the fork) are gathered into lists on the
        source state.
        mode == "concate_and_append": the children's KV caches are concatenated
        and appended to the source (backend-level op).
        Raises ValueError for any other mode.  All child executors are ended.
        """
        if mode == "gather_variable":
            # Copy variables back
            src_vars = self.src_state.stream_executor.variables
            # Snapshot of pre-existing names: only NEW variables are gathered.
            src_var_set = set(src_vars.keys())
            for child_state in self.states:
                child_state.stream_executor.sync()
                child_vars = child_state.stream_executor.variables
                new_vars = set(child_vars.keys()) - src_var_set

                for k in new_vars:
                    # Collect each child's value into a list on the source.
                    if k in src_vars:
                        src_vars[k].append(child_vars[k])
                    else:
                        src_vars[k] = [child_vars[k]]
        elif mode == "concate_and_append":
            # Concatenate and append KV cache
            self.src_state += SglConcateAndAppend(self.states)
            # Need a sync here. Otherwise, `states` can be deleted.
            self.src_state.stream_executor.sync()
        else:
            raise ValueError(f"Invalid join mode: {mode}")

        for s in self.states:
            s.stream_executor.end()

    def __getitem__(self, i: int):
        # Access the i-th forked state.
        return self.states[i]

    def __setitem__(self, i: int, value):
        # Intentional no-op: `group[i] += expr` calls __getitem__, mutates the
        # state in place, then assigns the same object back — assert that.
        assert self.states[i] == value

    def __iadd__(self, other):
        # Broadcast `other` onto every child state.
        if isinstance(other, Callable):
            # lambda function: each child gets its own expression `other(i)`.
            for i in range(len(self.states)):
                self.states[i] += other(i)
        elif isinstance(other, SglExpr):
            # A single expression is sent to all children.
            for i in range(len(self.states)):
                self.states[i] += other
        elif isinstance(other, (list, tuple)):
            # A sequence is distributed element-wise.
            for i in range(len(self.states)):
                self.states[i] += other[i]
        else:
            raise ValueError(f"Invalid value: {other}")

        return self
|
||||
442
python/sglang/lang/ir.py
Normal file
442
python/sglang/lang/ir.py
Normal file
@@ -0,0 +1,442 @@
|
||||
"""The intermediate representation."""
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sglang.global_config import global_config
|
||||
|
||||
|
||||
@dataclasses.dataclass
class SamplingParams:
    """Sampling hyperparameters attached to a generation call."""

    max_new_tokens: int = 16
    stop: Union[str, List[str]] = ()
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1  # -1 means disable
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0

    # for constrained generation, not included in to_xxx_kwargs
    dtype: Optional[str] = None
    regex: Optional[str] = None

    def clone(self):
        """Return an independent copy of these parameters.

        Bug fix: the original re-built the object from only the first seven
        fields, silently dropping `dtype` and `regex` on clone.
        """
        return dataclasses.replace(self)

    def to_openai_kwargs(self):
        # OpenAI does not support top_k, so we drop it here
        return {
            "max_tokens": self.max_new_tokens,
            "stop": self.stop or None,  # OpenAI expects None, not ()
            "temperature": self.temperature,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
        }

    def to_anthropic_kwargs(self):
        # Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
        return {
            "max_tokens_to_sample": self.max_new_tokens,
            "stop_sequences": self.stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        }

    def to_srt_kwargs(self):
        # Native SRT backend accepts everything, including the regex constraint.
        return {
            "max_new_tokens": self.max_new_tokens,
            "stop": self.stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "regex": self.regex,
        }
|
||||
|
||||
|
||||
class SglFunction:
    """A runnable SGL program wrapping a user function `func(s, ...)`."""

    def __init__(self, func, bind_arguments=None):
        self.func = func
        # Arguments pre-bound via `bind`; merged into every invocation.
        self.bind_arguments = bind_arguments or {}
        # Request id of the pinned prefix, set by `pin` (None when unpinned).
        self.pin_prefix_rid = None

        # Parse arguments
        argspec = inspect.getfullargspec(func)
        assert argspec.args[0] == "s", 'The first argument must be "s"'
        self.arg_names = argspec.args[1:]

    def bind(self, **kwargs):
        """Return a new SglFunction with `kwargs` pre-bound as arguments."""
        assert all(key in self.arg_names for key in kwargs)

        new_bind_dict = {**self.bind_arguments, **kwargs}
        return SglFunction(self.func, bind_arguments=new_bind_dict)

    def run(
        self,
        *args,
        max_new_tokens: int = 16,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        stream: bool = False,
        backend=None,
        **kwargs,
    ):
        """Execute the program once on `backend` with these default sampling
        parameters; returns the interpreter's program state."""
        from sglang.lang.interpreter import run_program

        default_sampling_para = SamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )
        backend = backend or global_config.default_backend
        # Wrap user values so the interpreter can tell arguments from literals.
        kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
        return run_program(self, backend, args, kwargs, default_sampling_para, stream)

    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 16,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        num_threads: Union[str, int] = "auto",
        progress_bar: bool = False,
    ):
        """Execute the program once per element of `batch_kwargs`
        (a list/tuple of kwarg dicts), possibly in parallel threads."""
        from sglang.lang.interpreter import run_program_batch

        assert isinstance(batch_kwargs, (list, tuple))
        if len(batch_kwargs) == 0:
            return []
        assert isinstance(batch_kwargs[0], dict)

        default_sampling_para = SamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )
        backend = backend or global_config.default_backend
        batch_kwargs = [
            {k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
        ]
        return run_program_batch(
            self,
            backend,
            batch_kwargs,
            default_sampling_para,
            num_threads,
            progress_bar,
        )

    def trace(self, *, backend=None, **kwargs):
        """Trace the program into an IR graph without executing a backend call."""
        from sglang.lang.tracer import trace_program

        backend = backend or global_config.default_backend
        return trace_program(self, kwargs, backend)

    def pin(self, backend=None):
        """Pin this program's constant prefix in the backend's KV cache."""
        from sglang.lang.interpreter import pin_program

        backend = backend or global_config.default_backend
        return pin_program(self, backend)

    def unpin(self, backend=None):
        """Release a prefix previously pinned by `pin`."""
        from sglang.lang.interpreter import unpin_program

        backend = backend or global_config.default_backend
        return unpin_program(self, backend)

    def compile(self, *, backend=None):
        """Compile the program for repeated execution."""
        from sglang.lang.compiler import compile_func

        return compile_func(self, backend)

    def __call__(self, *args, **kwargs):
        # Inside an active tracing scope, calling a function traces it
        # (on the tracer's backend) instead of running it.
        from sglang.lang.tracer import TracingScope

        tracing_scope = TracingScope.get_current_scope()
        if tracing_scope is None:
            return self.run(*args, **kwargs)
        else:
            kwargs["backend"] = tracing_scope.tracer_state.backend
            return self.trace(*args, **kwargs)
|
||||
|
||||
|
||||
class SglExpr:
    """Base class of all IR nodes.

    Supports `+` / string concatenation (producing SglExprList) and assigns
    each node a unique id from a global counter.
    """

    # Global counter used to assign unique node ids.
    node_ct = 0

    def __init__(self):
        self.node_id = SglExpr.node_ct
        # Previous node in the dependency chain (set by the tracer).
        self.prev_node = None
        # Id of the program instance this node was submitted to.
        self.pid = None
        SglExpr.node_ct += 1

    def __add__(self, other):
        # `expr + "text"`: promote the plain string to a constant node.
        if isinstance(other, str):
            other = SglConstantText(other)
        assert isinstance(other, SglExpr)

        return self.concatenate_ir(self, other)

    def __radd__(self, other):
        # `"text" + expr`
        if isinstance(other, str):
            other = SglConstantText(other)
        assert isinstance(other, SglExpr), f"{other}"

        return self.concatenate_ir(other, self)

    def concatenate_ir(self, a, b):
        # Flatten nested SglExprList operands so concatenation stays shallow.
        if isinstance(a, SglExprList):
            if isinstance(b, SglExprList):
                return SglExprList(a.expr_list + b.expr_list)
            else:
                return SglExprList(a.expr_list + [b])
        elif isinstance(b, SglExprList):
            return SglExprList([a] + b.expr_list)

        return SglExprList([a, b])

    def print_graph_dfs(self):
        """Return a text dump of the IR graph, dependencies printed first."""
        ret = [""]  # one-element list so the nested closure can mutate it
        visited = set()

        def dfs_print(x):
            if x is None or x in visited:
                return
            visited.add(x)

            # Print dependency
            if x.prev_node is not None:
                dfs_print(x.prev_node)

            if isinstance(x, SglExprList):
                for y in x.expr_list:
                    dfs_print(y)
            # elif isinstance(x, SglRole):
            #     dfs_print(x.expr)
            elif isinstance(x, SglVariable):
                dfs_print(x.source)

            # Print the node itself
            if isinstance(x, (SglFork, SglGetForkItem)):
                ret[0] += f"%{x.node_id} = {x}\n"
            else:
                if x.prev_node is not None:
                    ret[0] += (
                        f"%{x.node_id} = %{x.prev_node.node_id} + " + str(x) + "\n"
                    )
                else:
                    ret[0] += f"%{x.node_id} = " + str(x) + "\n"

        dfs_print(self)
        return ret[0]
|
||||
|
||||
|
||||
class SglExprList(SglExpr):
    """An IR node that is a flat sequence of sub-expressions."""

    def __init__(self, expr_list: List[SglExpr]):
        super().__init__()
        self.expr_list = expr_list

    def __repr__(self):
        return "ExprList({})".format(self.expr_list)
|
||||
|
||||
|
||||
class SglArgument(SglExpr):
    """A placeholder bound to a runtime argument of an SglFunction."""

    def __init__(self, name: str, value: str):
        super().__init__()
        self.name = name
        self.value = value

    def __repr__(self):
        return "Argument(name={}, value={})".format(self.name, repr(self.value))

    # Delegate container/number protocols to the wrapped value so the
    # placeholder behaves like the value inside a traced program.
    def __len__(self):
        return len(self.value)

    def __getitem__(self, i):
        return self.value[i]

    def __int__(self):
        return self.value

    def __bool__(self):
        return self.value

    def __format__(self, *args):
        # Interpolating an argument into an f-string would bake its (dummy)
        # value into the prompt, which the tracer cannot represent.
        raise TypeError(
            "Cannot put argument inside a f-string. "
            "This is not compatible with the tracer. "
        )
|
||||
|
||||
|
||||
class SglImage(SglExpr):
    """An image input, referenced by file path."""

    def __init__(self, path):
        # Bug fix: call super().__init__() so node_id/prev_node/pid are
        # initialized like every other SglExpr subclass; without it, graph
        # utilities that read node_id would fail on image nodes.
        super().__init__()
        self.path = path

    def __repr__(self) -> str:
        return f"SglImage({self.path})"
|
||||
|
||||
|
||||
class SglGen(SglExpr):
    """A generation op producing variable `name` with given sampling params."""

    def __init__(
        self,
        name,
        max_new_tokens,
        stop,
        temperature,
        top_p,
        top_k,
        frequency_penalty,
        presence_penalty,
        dtype,
        regex,
    ):
        super().__init__()
        self.name = name
        # Bundle all knobs into a SamplingParams value object.
        self.sampling_params = SamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            dtype=dtype,
            regex=regex,
        )

    def __repr__(self):
        return "Gen('{}')".format(self.name)
|
||||
|
||||
|
||||
class SglConstantText(SglExpr):
    """A literal piece of prompt text."""

    def __init__(self, value):
        super().__init__()
        self.value = value

    def __repr__(self):
        return "Constant({})".format(repr(self.value))
|
||||
|
||||
|
||||
class SglRoleBegin(SglExpr):
    """Marks the opening of a chat role (system/user/assistant) turn."""

    def __init__(self, role):
        super().__init__()
        self.role = role

    def __repr__(self):
        return "RoleBegin({})".format(self.role)
|
||||
|
||||
|
||||
class SglRoleEnd(SglExpr):
    """Marks the closing of a chat role turn."""

    def __init__(self, role):
        super().__init__()
        self.role = role

    def __repr__(self):
        return "RoleEnd({})".format(self.role)
|
||||
|
||||
|
||||
class SglSelect(SglExpr):
    """Pick one of `choices` and store the result as variable `name`."""

    def __init__(self, name, choices, temperature):
        super().__init__()
        self.name = name
        self.choices = choices
        self.temperature = temperature

    def __repr__(self):
        return "Select({}, choices={})".format(self.name, self.choices)
|
||||
|
||||
|
||||
class SglFork(SglExpr):
    """Fork the current state into `number` parallel branches."""

    def __init__(self, number, position_ids_offset=None):
        super().__init__()
        self.number = number
        self.position_ids_offset = position_ids_offset

    def __repr__(self):
        return "Fork(%{}, number={}, position_ids_offset={})".format(
            self.prev_node.node_id, self.number, self.position_ids_offset
        )
|
||||
|
||||
|
||||
class SglGetForkItem(SglExpr):
    """Select the `index`-th branch produced by a preceding SglFork."""

    def __init__(self, index):
        super().__init__()
        self.index = index

    def __repr__(self):
        return "GetForkItem(%{}, index={})".format(self.prev_node.node_id, self.index)
|
||||
|
||||
|
||||
class SglVariable(SglExpr):
    """A reference to the value produced by another node (`source`)."""

    def __init__(self, name, source):
        super().__init__()
        self.name = name
        self.source = source

    def __repr__(self):
        return "Variable('{}', source=%{})".format(self.name, self.source.node_id)
|
||||
|
||||
|
||||
class SglVarScopeBegin(SglExpr):
    """Opens a named scope whose generated content becomes variable `name`."""

    def __init__(self, name):
        super().__init__()
        self.name = name

    def __repr__(self):
        return "VarScopeBegin('{}')".format(self.name)
|
||||
|
||||
|
||||
class SglVarScopeEnd(SglExpr):
    """Closes the named scope opened by SglVarScopeBegin."""

    def __init__(self, name):
        super().__init__()
        self.name = name

    def __repr__(self):
        return "VarScopeEnd('{}')".format(self.name)
|
||||
|
||||
|
||||
class SglConcateAndAppend(SglExpr):
    """Concatenate forked `states` and append them to the current state."""

    def __init__(self, states):
        super().__init__()
        self.states = states

    def __repr__(self):
        return "ConcatenateAndAppend('{}')".format(self.states)
|
||||
|
||||
|
||||
class SglCommitLazy(SglExpr):
    """Commits lazily buffered operations."""

    def __init__(self):
        super().__init__()

    def __repr__(self):
        # Idiom fix: was an f-string with no placeholders.
        return "CommitLazy()"
|
||||
279
python/sglang/lang/tracer.py
Normal file
279
python/sglang/lang/tracer.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""Tracing a program."""
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
from sglang.backend.base_backend import BaseBackend
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
|
||||
from sglang.lang.ir import (
|
||||
SglArgument,
|
||||
SglCommitLazy,
|
||||
SglConcateAndAppend,
|
||||
SglConstantText,
|
||||
SglExpr,
|
||||
SglExprList,
|
||||
SglFork,
|
||||
SglFunction,
|
||||
SglGen,
|
||||
SglGetForkItem,
|
||||
SglRoleBegin,
|
||||
SglRoleEnd,
|
||||
SglSelect,
|
||||
SglVariable,
|
||||
SglVarScopeBegin,
|
||||
SglVarScopeEnd,
|
||||
)
|
||||
|
||||
|
||||
class StopTracing(Exception):
    """Raised internally to abort tracing once the prefix has been captured."""
|
||||
|
||||
|
||||
def extract_prefix_by_tracing(program, backend):
    """Trace `program` just far enough to extract its constant text prefix.

    Every argument is replaced with a dummy placeholder; tracing stops (via
    StopTracing) at the first operation that is not constant text.
    """
    # Create dummy arguments for every declared parameter, then overlay
    # any pre-bound arguments.
    arguments = {name: SglArgument(name, None) for name in program.arg_names}
    arguments.update(program.bind_arguments)

    # Trace in prefix-only mode.
    tracer = TracerProgramState(backend, arguments, only_trace_prefix=True)
    try:
        with TracingScope(tracer):
            tracer.ret_value = program.func(tracer, **arguments)
    except StopTracing:
        pass

    # Concatenate the leading run of constant-text nodes.
    pieces = []
    for node in tracer.flatten_nodes():
        if not isinstance(node, SglConstantText):
            break
        pieces.append(node.value)
    return "".join(pieces)
|
||||
|
||||
|
||||
def trace_program(program, arguments, backend):
    """Trace `program` fully into an IR graph and return the tracer state."""
    # Create dummy backend
    if backend is None:
        backend = BaseBackend()

    # Any parameter the caller did not supply gets a dummy placeholder;
    # pre-bound arguments are overlaid last.
    for name in program.arg_names:
        if name not in arguments:
            arguments[name] = SglArgument(name, None)
    arguments.update(program.bind_arguments)

    # Trace
    tracer = TracerProgramState(backend, arguments, only_trace_prefix=False)
    with TracingScope(tracer):
        tracer.ret_value = program.func(tracer, **arguments)
    return tracer
|
||||
|
||||
|
||||
class TracerProgramState(ProgramState):
    """A ProgramState that records the IR graph instead of executing it."""

    def __init__(self, backend, arguments, only_trace_prefix):
        # Unique id of this traced program instance.
        self.pid = uuid.uuid4().hex
        self.backend = backend
        self.arguments: Dict[str, Any] = arguments
        # When True, tracing aborts (StopTracing) at the first op that is
        # not constant text — used for prefix extraction.
        self.only_trace_prefix = only_trace_prefix

        if hasattr(backend, "endpoint"):
            self.backend = backend.endpoint

        # Recorded IR nodes, in submission order.
        self.nodes = []
        self.last_node = None
        self.variables = {}
        self.ret_value = None

        # For completion

        # For chat
        self.messages_ = []
        self.cur_role = None
        self.chat_template = self.backend.get_chat_template()

        # For multi states
        self.child_states = []

        cur_scope = TracingScope.get_current_scope()
        if cur_scope is not None:
            cur_scope.add_child_state(self)

    ##################################
    ########### Public API ###########
    ##################################

    def fork(self, number: int, position_ids_offset: Optional[List[int]] = None):
        """Record a fork node and return a group of child tracer states."""
        if self.only_trace_prefix:
            raise StopTracing()

        # Bug fix: position_ids_offset was accepted but never forwarded
        # to the SglFork node.
        fork_node = SglFork(number, position_ids_offset)
        fork_node.prev_node = self.last_node

        states = [
            TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
            for _ in range(number)
        ]

        for i in range(number):
            node = SglGetForkItem(i)
            node.prev_node = fork_node
            states[i].last_node = node
            # Each child starts from a copy of the parent's chat context.
            states[i].variables = dict(self.variables)
            states[i].messages_ = list(self.messages_)
            states[i].cur_role = self.cur_role
            states[i].chat_template = self.chat_template

        state_group = ProgramStateGroup(states, self)

        return state_group

    ##################################
    ########## Internal API ##########
    ##################################

    def _append_node(self, other: SglExpr):
        # Link the node into the dependency chain and record it.
        self.nodes.append(other)
        other.prev_node = self.last_node
        self.last_node = other

    def _execute(self, other: SglExpr):
        """Dispatch one IR op to its recording handler."""
        if isinstance(other, str):
            other = SglConstantText(other)

        other.pid = self.pid

        if isinstance(other, SglConstantText):
            self._execute_fill(other)
        elif isinstance(other, SglGen):
            self._execute_gen(other)
        elif isinstance(other, SglSelect):
            self._execute_select(other)
        elif isinstance(other, SglExprList):
            for x in other.expr_list:
                self._execute(x)
        elif isinstance(other, SglRoleBegin):
            self._execute_role_begin(other)
        elif isinstance(other, SglRoleEnd):
            self._execute_role_end(other)
        elif isinstance(other, SglVarScopeBegin):
            self._execute_var_scope_begin(other)
        elif isinstance(other, SglVarScopeEnd):
            self._execute_var_scope_end(other)
        else:
            # Unknown ops end prefix tracing; otherwise record them as-is.
            if self.only_trace_prefix:
                raise StopTracing()
            else:
                self._append_node(other)

        return self

    def __iadd__(self, other):
        self._execute(other)
        return self

    def _execute_fill(self, expr: SglConstantText):
        if isinstance(expr, str):
            expr = SglConstantText(expr)
        self._append_node(expr)

    def _execute_gen(self, expr: SglGen):
        # Unnamed generations get a positional default name.
        name = expr.name if expr.name is not None else "gen_" + str(len(self.variables))
        new_node = SglVariable(name, source=expr)
        self.variables[name] = new_node
        self._append_node(expr)

    def _execute_select(self, expr: SglSelect):
        name = (
            expr.name if expr.name is not None else "select_" + str(len(self.variables))
        )
        new_node = SglVariable(name, source=expr)
        self.variables[name] = new_node
        self._append_node(expr)

    def _execute_role_begin(self, expr: SglRoleBegin):
        assert self.cur_role is None, "Nested roles are not allowed."

        if len(self.messages_) == 0 and expr.role != "system":
            # Insert default system message
            default_system = self.chat_template.default_system_prompt
            if default_system:
                self._execute_role_begin(SglRoleBegin("system"))
                self._execute_fill(default_system)
                self._execute_role_end(SglRoleEnd("system"))

        self.cur_role = expr.role

        prefix, suffix = self.chat_template.get_prefix_and_suffix(
            expr.role, self.messages_
        )

        self._execute_fill(prefix)

    def _execute_role_end(self, expr: SglRoleEnd):
        prefix, suffix = self.chat_template.get_prefix_and_suffix(
            expr.role, self.messages_
        )

        self._execute_fill(suffix)

        # The tracer does not know the generated content, only the turn shape.
        self.messages_.append({"role": expr.role, "content": ""})

        self.cur_role = None

    def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
        # Bug fix: this handler was missing, so `_execute` raised
        # AttributeError on SglVarScopeBegin. The variable is captured at
        # scope end; no graph node is needed here. TODO(review): confirm no
        # begin-marker bookkeeping is required for nested scopes.
        pass

    def _execute_var_scope_end(self, expr: SglVarScopeEnd):
        # Bug fix: was `SglVariable(name, ...)` with `name` undefined
        # (NameError); the scope's name lives on the expression.
        new_node = SglVariable(expr.name, source=self.last_node)
        self.variables[expr.name] = new_node

    def get_var(self, name):
        """Resolve `name` as an argument first, then as a traced variable."""
        ret = self.arguments.get(name, None)
        if ret is not None:
            return ret

        v = self.variables[name]
        return SglVariable(v.name, v.source)

    def flatten_nodes(self):
        """Return the recorded nodes with SglExprList containers flattened."""
        def traverse(cur):
            if isinstance(cur, SglExprList):
                for child in cur.expr_list:
                    traverse(child)
            else:
                ret.append(cur)

        ret = []
        for x in self.nodes:
            traverse(x)
        return ret

    def __del__(self):
        # Override ProgramState.__del__: there is no executor to shut down.
        pass
|
||||
|
||||
|
||||
class TracingScope:
    """A process-global stack of active tracers, managed via `with` blocks."""

    # The innermost active scope (None when not tracing).
    cur_scope = None

    def __init__(self, tracer_state: TracerProgramState):
        self.tracer_state = tracer_state
        # Remember the enclosing scope so it can be restored on exit.
        self.last_scope = TracingScope.cur_scope

    def __enter__(self):
        TracingScope.cur_scope = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        TracingScope.cur_scope = self.last_scope

    @staticmethod
    def get_current_scope():
        return TracingScope.cur_scope

    def add_child_state(self, state: TracerProgramState):
        """Register `state` with this scope and every enclosing scope."""
        cur_scope = self
        # Idiom fix: `!= None` -> `is not None` (identity check for None).
        while cur_scope is not None:
            cur_scope.tracer_state.child_states.append(state)
            cur_scope = cur_scope.last_scope
|
||||
Reference in New Issue
Block a user