diff --git a/README.md b/README.md index 9ae67d5bf..ee93680e6 100644 --- a/README.md +++ b/README.md @@ -433,6 +433,24 @@ for out in state.text_iter(): print(out, end="", flush=True) ``` +#### Roles + +Use `sgl.system`, `sgl.user` and `sgl.assistant` to set roles when using Chat models. You can also define more complex role prompts using begin and end tokens. + +```python +@sgl.function +def chat_example(s): + s += sgl.system("You are a helpful assistant.") + # Same as: s += s.system("You are a helpful assistant.") + + with s.user(): + s += "Question: What is the capital of France?" + + s += sgl.assistant_begin() + s += "Answer: " + sgl.gen(max_tokens=100, stop="\n") + s += sgl.assistant_end() +``` + #### Tips and Implementation Details - The `choices` argument in `sgl.gen` is implemented by computing the [token-length normalized log probabilities](https://blog.eleuther.ai/multiple-choice-normalization/) of all choices and selecting the one with the highest probability. - The `regex` argument in `sgl.gen` is implemented through autoregressive decoding with logit bias masking, according to the constraints set by the regex. It is compatible with `temperature=0` and `temperature != 0`. diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py index b03d3d250..413ab9e7c 100644 --- a/python/sglang/__init__.py +++ b/python/sglang/__init__.py @@ -14,6 +14,8 @@ from sglang.api import ( select, set_default_backend, system, + system_begin, + system_end, user, user_begin, user_end, @@ -60,4 +62,6 @@ __all__ = [ "user_end", "assistant_begin", "assistant_end", + "system_begin", + "system_end", ] diff --git a/python/sglang/api.py b/python/sglang/api.py index 70f992b14..e6b6715a8 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -210,6 +210,14 @@ def assistant(expr: Optional[SglExpr] = None): return _role_common("assistant", expr) +def system_begin(): + return SglRoleBegin("system") + + +def system_end(): + return SglRoleEnd("system") + + def user_begin(): return SglRoleBegin("user") diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index ddf755ca2..61f0a0259 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -705,9 +705,9 @@ class ProgramState: def _role_common(self, name: str, expr: Optional[SglExpr] = None): if expr is not None: - self.stream_executor.submit( - SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)]) - ) + role_expr = SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)]) + self.stream_executor.submit(role_expr) + return role_expr else: @contextmanager