Frontend language separate reasoning support (#6031)
This commit is contained in:
@@ -15,6 +15,7 @@ from sglang.api import (
|
||||
get_server_info,
|
||||
image,
|
||||
select,
|
||||
separate_reasoning,
|
||||
set_default_backend,
|
||||
system,
|
||||
system_begin,
|
||||
@@ -54,6 +55,7 @@ __all__ = [
|
||||
"get_server_info",
|
||||
"image",
|
||||
"select",
|
||||
"separate_reasoning",
|
||||
"set_default_backend",
|
||||
"system",
|
||||
"system_begin",
|
||||
|
||||
@@ -15,6 +15,7 @@ from sglang.lang.ir import (
|
||||
SglRoleBegin,
|
||||
SglRoleEnd,
|
||||
SglSelect,
|
||||
SglSeparateReasoning,
|
||||
SglVideo,
|
||||
)
|
||||
|
||||
@@ -277,3 +278,9 @@ def assistant_begin():
|
||||
|
||||
def assistant_end():
    """Return an ``SglRoleEnd`` IR node for the ``"assistant"`` role."""
    role_end = SglRoleEnd("assistant")
    return role_end
|
||||
|
||||
|
||||
def separate_reasoning(
    expr: Optional[SglExpr] = None, model_type: Optional[str] = None
):
    """Pair ``expr`` with an ``SglSeparateReasoning`` node built from it.

    Args:
        expr: Expression to wrap. It is placed first in the returned list
            and is also handed to ``SglSeparateReasoning``.
        model_type: Forwarded unchanged to ``SglSeparateReasoning``.

    Returns:
        An ``SglExprList`` of two items: ``expr`` itself, followed by the
        ``SglSeparateReasoning`` node that refers to the same ``expr``.
    """
    # NOTE(review): when ``expr`` is None it is still inserted into the list
    # as-is; confirm that ``SglExprList`` consumers tolerate a None entry.
    reasoning_node = SglSeparateReasoning(model_type, expr=expr)
    return SglExprList([expr, reasoning_node])
|
||||
|
||||
@@ -26,6 +26,7 @@ from sglang.lang.ir import (
|
||||
SglRoleBegin,
|
||||
SglRoleEnd,
|
||||
SglSelect,
|
||||
SglSeparateReasoning,
|
||||
SglVariable,
|
||||
SglVarScopeBegin,
|
||||
SglVarScopeEnd,
|
||||
@@ -472,6 +473,8 @@ class StreamExecutor:
|
||||
self._execute_concatenate_and_append_kv_cache(other)
|
||||
else:
|
||||
self._execute_concatenate_and_append_text(other)
|
||||
elif isinstance(other, SglSeparateReasoning):
|
||||
self._execute_separate_reasoning(other)
|
||||
else:
|
||||
raise ValueError(f"Unknown type: {type(other)}")
|
||||
|
||||
@@ -724,8 +727,44 @@ class StreamExecutor:
|
||||
src_rids = [state.stream_executor.sid for state in expr.states]
|
||||
self.backend.concatenate_and_append(src_rids, self.sid)
|
||||
|
||||
def _execute_separate_reasoning(self, expr: SglSeparateReasoning):
    """Split the stored text of ``expr.expr`` into reasoning and normal parts.

    For an ``SglGen`` / ``SglSelect`` target, the variable named
    ``target.name`` is read, run through ``ReasoningParser.parse_non_stream``
    and then:

    * the original variable is overwritten with the normal text,
    * the reasoning text is stored under the name produced by
      ``expr.process_name_for_reasoning`` and its event is set,
    * ``self.text_`` is rebuilt as everything before
      ``self.cur_role_begin_pos`` followed by the normal text.

    For an ``SglExprList`` target, every element is processed recursively
    with the same ``model_type``. Any other target type is ignored.

    Args:
        expr: The separate-reasoning node to execute.
    """
    # separate reasoning for stream is not supported
    # NOTE(review): on this path the reasoning variable's event is never set
    # by this method; confirm that readers of that variable cannot block.
    if self.stream:
        return

    has_pending_lazy_calls = (
        self.cur_role == "assistant"
        and self.num_api_spec_tokens is not None
        and self.backend.is_chat_model
    )
    if has_pending_lazy_calls:
        # Execute the stored lazy generation calls
        self.backend.role_end_generate(self)

    from sglang.srt.reasoning_parser import ReasoningParser

    # The parser is built before the emptiness check below, so a failure in
    # its constructor surfaces even when there is nothing to process.
    parser = ReasoningParser(expr.model_type)
    target = expr.expr

    if not target:
        return

    if isinstance(target, (SglGen, SglSelect)):
        raw_text = self.get_var(target.name)
        reasoning_text, answer_text = parser.parse_non_stream(raw_text)
        reasoning_var = expr.process_name_for_reasoning(target.name)

        self.set_var(target.name, answer_text)
        self.set_var(reasoning_var, reasoning_text)
        # the variable is ready to be used
        self.variable_event[reasoning_var].set()

        # Keep only the text that precedes the current role, then append the
        # answer with its reasoning removed.
        prefix = self.text_[: self.cur_role_begin_pos]
        self.text_ = prefix + answer_text
        return

    if isinstance(target, SglExprList):
        for child in target.expr_list:
            child_node = SglSeparateReasoning(expr.model_type, child)
            self._execute_separate_reasoning(child_node)
|
||||
|
||||
def _init_var_event(self, expr):
|
||||
if isinstance(expr, (SglGen, SglSelect, SglVarScopeBegin)):
|
||||
if isinstance(
|
||||
expr, (SglGen, SglSelect, SglVarScopeBegin, SglSeparateReasoning)
|
||||
):
|
||||
self.variable_event[expr.name] = threading.Event()
|
||||
if self.stream:
|
||||
self.stream_var_event[expr.name] = threading.Event()
|
||||
|
||||
@@ -606,3 +606,30 @@ class SglCommitLazy(SglExpr):
|
||||
|
||||
def __repr__(self):
|
||||
return "CommitLazy()"
|
||||
|
||||
|
||||
class SglSeparateReasoning(SglExpr):
    """IR node that ties a model type to an expression for reasoning separation.

    Attributes are set in ``__init__``:
      * ``model_type`` - stored exactly as passed in.
      * ``expr`` - the wrapped expression.
      * ``name`` - derived reasoning-variable name, or ``None`` when ``expr``
        is not (and does not contain) an ``SglGen`` / ``SglSelect``.
    """

    def __init__(self, model_type: str, expr: SglExpr):
        """Store the arguments and derive ``self.name`` from ``expr``."""
        super().__init__()
        self.model_type = model_type

        self.expr = expr
        # Start unnamed; ``_process_expr`` fills this in when it can.
        self.name = None
        self._process_expr(expr)

    def process_name_for_reasoning(self, name):
        """Return ``name`` with the ``_reasoning_content`` suffix appended.

        Raises:
            ValueError: If ``name`` is falsy (for example ``None`` or ``""``).
        """
        if name:
            return f"{name}_reasoning_content"
        raise ValueError("name must be provided")

    def _process_expr(self, expr):
        """Set ``self.name`` from ``expr``, descending into expression lists.

        Expressions of any other type leave ``self.name`` untouched.
        """
        if isinstance(expr, (SglGen, SglSelect)):
            self.name = self.process_name_for_reasoning(expr.name)
            return

        if isinstance(expr, SglExprList):
            # NOTE(review): each nested match overwrites ``self.name``, so
            # the last one in traversal order wins; confirm this is intended.
            for child in expr.expr_list:
                self._process_expr(child)

    def __repr__(self):
        """Return a short debug representation of this node."""
        return f"SeparateReasoning(model_type={self.model_type}, name={self.name})"
|
||||
|
||||
Reference in New Issue
Block a user