Files
sglang/python/sglang/lang/ir.py

473 lines
13 KiB
Python

"""The intermediate representation."""
import dataclasses
import inspect
import warnings
from typing import List, Optional, Union
from sglang.global_config import global_config
# Regular-expression source strings for primitive literal values.
# NOTE(review): presumably consumed by regex-constrained generation (see
# SglSamplingParams.regex) -- confirm against the callers.

# Optionally signed run of decimal digits.
REGEX_INT = r"[-+]?[0-9]+"
# Optionally signed decimal number; the integer part and the dot are optional.
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+"
# The literal word True or False.
REGEX_BOOL = r"(True|False)"
# Double-quoted run of word characters, digits and whitespace. The more
# general pattern r"\".*\"" is avoided because it hits bugs in the
# interegular package.
REGEX_STRING = r"\"[\w\d\s]*\""
@dataclasses.dataclass
class SglSamplingParams:
    """Sampling parameters for one generation call.

    The ``to_*_kwargs`` methods translate these fields into the keyword
    names each backend expects, dropping the fields that backend does not
    accept.
    """

    max_new_tokens: int = 16
    stop: Union[str, List[str]] = ()
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1  # -1 means disable
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    ignore_eos: bool = False

    # For constrained generation. ``dtype`` is not forwarded by any
    # to_xxx_kwargs method; ``regex`` is forwarded only by to_srt_kwargs.
    dtype: Optional[str] = None
    regex: Optional[str] = None

    def clone(self):
        """Return a new instance carrying the same value for every field.

        The copy is shallow: a list passed as ``stop`` is shared with the
        original.
        """
        # dataclasses.replace copies every declared field, so fields such as
        # ignore_eos, dtype and regex are preserved, and any field added
        # later is picked up automatically.
        return dataclasses.replace(self)

    def to_openai_kwargs(self):
        """Return the keyword arguments for the OpenAI backend.

        Emits a warning when ``regex`` is set, because it is not forwarded.
        An empty ``stop`` is sent as ``None``.
        """
        # OpenAI does not support top_k, so we drop it here
        if self.regex is not None:
            warnings.warn("Regular expression is not supported in the OpenAI backend.")
        return {
            "max_tokens": self.max_new_tokens,
            "stop": self.stop or None,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
        }

    def to_vertexai_kwargs(self):
        """Return the keyword arguments for the VertexAI backend.

        Emits a warning when ``regex`` is set, because it is not forwarded.
        A non-positive ``top_k`` is sent as ``None``.
        """
        if self.regex is not None:
            warnings.warn(
                "Regular expression is not supported in the VertexAI backend."
            )
        return {
            "candidate_count": 1,
            "max_output_tokens": self.max_new_tokens,
            "stop_sequences": self.stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k if self.top_k > 0 else None,
        }

    def to_anthropic_kwargs(self):
        """Return the keyword arguments for the Anthropic backend.

        Emits a warning when ``regex`` is set, because it is not forwarded.
        """
        # Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
        if self.regex is not None:
            warnings.warn(
                "Regular expression is not supported in the Anthropic backend."
            )
        return {
            "max_tokens_to_sample": self.max_new_tokens,
            "stop_sequences": self.stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        }

    def to_srt_kwargs(self):
        """Return the keyword arguments for the SRT backend.

        Every field except ``dtype`` is forwarded, including ``regex``.
        """
        return {
            "max_new_tokens": self.max_new_tokens,
            "stop": self.stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "ignore_eos": self.ignore_eos,
            "regex": self.regex,
        }
class SglFunction:
    """Wrap a user-written program function so it can be run, traced or pinned.

    The wrapped function must declare a first positional parameter named
    ``s``; the names of the remaining positional parameters are recorded in
    ``arg_names``.
    """

    def __init__(self, func, bind_arguments=None):
        """Record ``func`` and the arguments already bound to it.

        Args:
            func: The program function; its first parameter must be ``s``.
            bind_arguments: Optional dict of pre-bound keyword arguments.

        Raises:
            AssertionError: If ``func`` has no positional parameters or the
                first one is not named ``s``.
        """
        self.func = func
        self.bind_arguments = bind_arguments or {}
        # NOTE(review): never assigned again in this file; presumably
        # managed by the pin/unpin logic in the interpreter -- confirm.
        self.pin_prefix_rid = None

        # Parse arguments
        argspec = inspect.getfullargspec(func)
        # Test for emptiness first so that a zero-argument function fails
        # with the explanatory message rather than a bare IndexError.
        assert argspec.args and argspec.args[0] == "s", 'The first argument must be "s"'
        self.arg_names = argspec.args[1:]

    def bind(self, **kwargs):
        """Return a new SglFunction with ``kwargs`` added to the bound arguments.

        Every key must be one of ``arg_names``. Newly given values override
        earlier bindings of the same name; this instance is not modified.
        """
        assert all(key in self.arg_names for key in kwargs)

        new_bind_dict = {**self.bind_arguments, **kwargs}
        return SglFunction(self.func, bind_arguments=new_bind_dict)

    def run(
        self,
        *args,
        max_new_tokens: int = 16,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        ignore_eos: bool = False,
        stream: bool = False,
        backend=None,
        **kwargs,
    ):
        """Run the program once through ``run_program``.

        The sampling keyword arguments are packed into an SglSamplingParams
        that serves as the default for the run. ``args`` and ``kwargs`` are
        forwarded as the program's own arguments. When ``backend`` is falsy,
        ``global_config.default_backend`` is used.
        """
        # Imported here rather than at module level (presumably to avoid a
        # circular import with the interpreter).
        from sglang.lang.interpreter import run_program

        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            ignore_eos=ignore_eos,
        )
        backend = backend or global_config.default_backend
        return run_program(self, backend, args, kwargs, default_sampling_para, stream)

    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 16,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        ignore_eos: bool = False,
        backend=None,
        num_threads: Union[str, int] = "auto",
        progress_bar: bool = False,
    ):
        """Run the program once per entry of ``batch_kwargs``.

        Args:
            batch_kwargs: List or tuple of dicts, one per run. An empty
                batch returns ``[]`` without touching the backend.
            num_threads: Forwarded unchanged to ``run_program_batch``.
            progress_bar: Forwarded unchanged to ``run_program_batch``.

        The sampling keyword arguments and ``backend`` behave as in ``run``.
        """
        from sglang.lang.interpreter import run_program_batch

        assert isinstance(batch_kwargs, (list, tuple))
        if len(batch_kwargs) == 0:
            return []
        # Only the first entry is type-checked.
        assert isinstance(batch_kwargs[0], dict)

        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            ignore_eos=ignore_eos,
        )
        backend = backend or global_config.default_backend
        return run_program_batch(
            self,
            backend,
            batch_kwargs,
            default_sampling_para,
            num_threads,
            progress_bar,
        )

    def trace(self, *, backend=None, **kwargs):
        """Trace the program with ``trace_program``; arguments are keyword-only."""
        from sglang.lang.tracer import trace_program

        backend = backend or global_config.default_backend
        return trace_program(self, kwargs, backend)

    def pin(self, backend=None):
        """Call ``pin_program`` for this function on the chosen backend."""
        from sglang.lang.interpreter import pin_program

        backend = backend or global_config.default_backend
        return pin_program(self, backend)

    def unpin(self, backend=None):
        """Call ``unpin_program`` for this function on the chosen backend."""
        from sglang.lang.interpreter import unpin_program

        backend = backend or global_config.default_backend
        return unpin_program(self, backend)

    def compile(self, *, backend=None):
        """Compile the program with ``compile_func``.

        Unlike the other entry points, ``backend`` is passed through as
        given, without falling back to the global default.
        """
        from sglang.lang.compiler import compile_func

        return compile_func(self, backend)

    def __call__(self, *args, **kwargs):
        """Run the program, or trace it when a tracing scope is active."""
        from sglang.lang.tracer import TracingScope

        tracing_scope = TracingScope.get_current_scope()
        if tracing_scope is None:
            return self.run(*args, **kwargs)
        else:
            kwargs["backend"] = tracing_scope.tracer_state.backend
            # NOTE(review): trace() accepts keyword arguments only, so any
            # positional argument raises TypeError on this path -- confirm
            # that this is intended.
            return self.trace(*args, **kwargs)
class SglExpr:
    """Base class of the IR nodes.

    Each node receives a sequential ``node_id`` taken from the class-wide
    counter ``node_ct``. ``prev_node`` and ``pid`` start out as ``None``.
    """

    # Number of nodes created so far; also the id handed to the next node.
    node_ct = 0

    def __init__(self):
        self.node_id = SglExpr.node_ct
        SglExpr.node_ct += 1
        self.prev_node = None
        self.pid = None

    def __add__(self, other):
        """Return ``self + other`` as an expression list; a str is wrapped first."""
        rhs = SglConstantText(other) if isinstance(other, str) else other
        assert isinstance(rhs, SglExpr)
        return self.concatenate_ir(self, rhs)

    def __radd__(self, other):
        """Return ``other + self`` as an expression list; a str is wrapped first."""
        lhs = SglConstantText(other) if isinstance(other, str) else other
        assert isinstance(lhs, SglExpr), f"{lhs}"
        return self.concatenate_ir(lhs, self)

    def concatenate_ir(self, a, b):
        """Join ``a`` and ``b`` into one flat SglExprList.

        An operand that is already an SglExprList contributes its elements
        instead of being nested.
        """
        left = a.expr_list if isinstance(a, SglExprList) else [a]
        right = b.expr_list if isinstance(b, SglExprList) else [b]
        return SglExprList(left + right)

    def print_graph_dfs(self):
        """Return a text dump of the graph reachable from this node.

        Every node is emitted once, after the nodes it depends on, as one
        ``%id = ...`` line.
        """
        lines = []
        seen = set()

        def visit(node):
            if node is None or node in seen:
                return
            seen.add(node)

            # Emit the dependencies first.
            if node.prev_node is not None:
                visit(node.prev_node)

            if isinstance(node, SglExprList):
                for child in node.expr_list:
                    visit(child)
            elif isinstance(node, SglVariable):
                visit(node.source)

            # Then emit the node itself. Fork nodes name their predecessor
            # in their own repr, so no "%prev + " prefix is added for them.
            if isinstance(node, (SglFork, SglGetForkItem)):
                lines.append(f"%{node.node_id} = {node}\n")
            elif node.prev_node is not None:
                lines.append(
                    f"%{node.node_id} = %{node.prev_node.node_id} + " + str(node) + "\n"
                )
            else:
                # str() rather than an f-string placeholder: SglArgument
                # raises from __format__.
                lines.append(f"%{node.node_id} = " + str(node) + "\n")

        visit(self)
        return "".join(lines)
class SglExprList(SglExpr):
    """IR node holding an ordered list of sub-expressions."""

    def __init__(self, expr_list: List[SglExpr]) -> None:
        super().__init__()
        # Stored by reference, not copied.
        self.expr_list = expr_list

    def __repr__(self) -> str:
        return f"ExprList({self.expr_list})"
class SglArgument(SglExpr):
    """IR node wrapping a named program argument and its value.

    ``len()``, indexing, ``int()`` and ``bool()`` are delegated to the
    wrapped value. Formatting is rejected outright.
    """

    def __init__(self, name: str, value: str) -> None:
        super().__init__()
        self.name = name
        self.value = value

    def __repr__(self) -> str:
        return f"Argument(name={self.name}, value={self.value!r})"

    def __len__(self):
        return len(self.value)

    def __getitem__(self, i):
        return self.value[i]

    def __int__(self):
        # The wrapped value is returned as is, without coercion; Python
        # raises TypeError when it is not an int.
        return self.value

    def __bool__(self):
        # The wrapped value is returned as is, without coercion; Python
        # raises TypeError when it is not a bool.
        return self.value

    def __format__(self, *args):
        # Interpolating an argument into an f-string would bake its value
        # into the text, which the tracer cannot handle.
        raise TypeError(
            "Cannot put argument inside a f-string. This is not compatible with the tracer. "
        )
class SglImage(SglExpr):
    """IR node referring to an image by its path."""

    def __init__(self, path):
        # Initialise the base node so the instance has node_id, prev_node
        # and pid like every other SglExpr; without them print_graph_dfs
        # fails with AttributeError on an image node.
        super().__init__()
        self.path = path

    def __repr__(self) -> str:
        return f"SglImage({self.path})"
class SglGen(SglExpr):
    """IR node for a generation call.

    Holds the name of the generated variable and the sampling parameters
    requested for this call.
    """

    def __init__(
        self,
        name,
        max_new_tokens,
        stop,
        temperature,
        top_p,
        top_k,
        frequency_penalty,
        presence_penalty,
        ignore_eos,
        dtype,
        regex,
    ):
        super().__init__()
        self.name = name

        # Collect the per-call settings, then build the parameter object.
        settings = {
            "max_new_tokens": max_new_tokens,
            "stop": stop,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "ignore_eos": ignore_eos,
            "dtype": dtype,
            "regex": regex,
        }
        self.sampling_params = SglSamplingParams(**settings)

    def __repr__(self) -> str:
        return f"Gen('{self.name}')"
class SglConstantText(SglExpr):
    """IR node holding a constant value; SglExpr wraps plain strings in it."""

    def __init__(self, value) -> None:
        super().__init__()
        self.value = value

    def __repr__(self) -> str:
        return f"Constant({self.value!r})"
class SglRoleBegin(SglExpr):
    """IR node carrying a role; by its name, it opens that role's section."""

    def __init__(self, role) -> None:
        super().__init__()
        self.role = role

    def __repr__(self) -> str:
        return f"RoleBegin({self.role})"
class SglRoleEnd(SglExpr):
    """IR node carrying a role; by its name, it closes that role's section."""

    def __init__(self, role) -> None:
        super().__init__()
        self.role = role

    def __repr__(self) -> str:
        return f"RoleEnd({self.role})"
class SglSelect(SglExpr):
    """IR node for a selection among ``choices``, stored under ``name``.

    ``temperature`` is kept on the node but is not part of its repr.
    """

    def __init__(self, name, choices, temperature) -> None:
        super().__init__()
        self.name = name
        self.choices = choices
        self.temperature = temperature

    def __repr__(self) -> str:
        return f"Select({self.name}, choices={self.choices})"
class SglFork(SglExpr):
    """IR node for a fork into ``number`` branches.

    The repr reads ``prev_node.node_id``, so ``prev_node`` must be set
    before the node is printed.
    """

    def __init__(self, number, position_ids_offset=None) -> None:
        super().__init__()
        self.number = number
        self.position_ids_offset = position_ids_offset

    def __repr__(self) -> str:
        parent_id = self.prev_node.node_id
        return (
            f"Fork(%{parent_id}, number={self.number}, "
            f"position_ids_offset={self.position_ids_offset})"
        )
class SglGetForkItem(SglExpr):
    """IR node selecting the item at ``index`` from its predecessor.

    The repr reads ``prev_node.node_id``, so ``prev_node`` must be set
    before the node is printed.
    """

    def __init__(self, index) -> None:
        super().__init__()
        self.index = index

    def __repr__(self) -> str:
        parent_id = self.prev_node.node_id
        return f"GetForkItem(%{parent_id}, index={self.index})"
class SglVariable(SglExpr):
    """IR node naming a variable and pointing at the node that is its source."""

    def __init__(self, name, source) -> None:
        super().__init__()
        self.name = name
        # Expected to be a node: the repr reads source.node_id.
        self.source = source

    def __repr__(self) -> str:
        source_id = self.source.node_id
        return f"Variable('{self.name}', source=%{source_id})"
class SglVarScopeBegin(SglExpr):
    """IR node carrying a name; by its name, it opens that variable's scope."""

    def __init__(self, name) -> None:
        super().__init__()
        self.name = name

    def __repr__(self) -> str:
        return f"VarScopeBegin('{self.name}')"
class SglVarScopeEnd(SglExpr):
    """IR node carrying a name; by its name, it closes that variable's scope."""

    def __init__(self, name) -> None:
        super().__init__()
        self.name = name

    def __repr__(self) -> str:
        return f"VarScopeEnd('{self.name}')"
class SglConcateAndAppend(SglExpr):
    """IR node holding a collection of states.

    NOTE(review): by its name, the states are concatenated and appended by
    whoever consumes this node -- that code is not in this file.
    """

    def __init__(self, states) -> None:
        super().__init__()
        self.states = states

    def __repr__(self) -> str:
        return f"ConcatenateAndAppend('{self.states}')"
class SglCommitLazy(SglExpr):
    """Payload-free IR node; only the base node fields are set."""

    def __init__(self) -> None:
        super().__init__()

    def __repr__(self) -> str:
        # No placeholders, so a plain literal is enough.
        return "CommitLazy()"