Json Decode && Mutl-Turns (#4)
This commit is contained in:
@@ -6,7 +6,7 @@ from typing import List, Union
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
|
||||
from sglang.lang.ir import (
|
||||
SamplingParams,
|
||||
SglSamplingParams,
|
||||
SglArgument,
|
||||
SglConstantText,
|
||||
SglExpr,
|
||||
@@ -140,7 +140,7 @@ class CompiledFunction:
|
||||
kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
|
||||
kwargs.update(self.function.bind_arguments)
|
||||
|
||||
default_sampling_para = SamplingParams(
|
||||
default_sampling_para = SglSamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop=stop,
|
||||
temperature=temperature,
|
||||
@@ -173,7 +173,7 @@ class CompiledFunction:
|
||||
|
||||
backend = backend or global_config.default_backend
|
||||
|
||||
default_sampling_para = SamplingParams(
|
||||
default_sampling_para = SglSamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop=stop,
|
||||
temperature=temperature,
|
||||
|
||||
@@ -292,7 +292,7 @@ class StreamExecutor:
|
||||
|
||||
assert isinstance(other, SglExpr), f"{other}"
|
||||
|
||||
if isinstance(other, (SglConstantText, SglArgument)):
|
||||
if isinstance(other, SglConstantText):
|
||||
self._execute_fill(other.value)
|
||||
elif isinstance(other, SglGen):
|
||||
self._execute_gen(other)
|
||||
@@ -332,8 +332,6 @@ class StreamExecutor:
|
||||
|
||||
def _execute_image(self, expr: SglImage):
|
||||
path = expr.path
|
||||
if isinstance(path, SglArgument):
|
||||
path = path.value
|
||||
|
||||
base64_data = encode_image_base64(path)
|
||||
|
||||
@@ -419,7 +417,7 @@ class StreamExecutor:
|
||||
"role": expr.role,
|
||||
"content": [{"type": "text", "text": new_text}],
|
||||
}
|
||||
for (image_path, image_base64_data) in self.cur_images:
|
||||
for image_path, image_base64_data in self.cur_images:
|
||||
last_msg["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
@@ -480,6 +478,7 @@ class StreamExecutor:
|
||||
"top_k",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"ignore_eos",
|
||||
"dtype",
|
||||
"regex",
|
||||
]:
|
||||
|
||||
@@ -13,7 +13,7 @@ REGEX_STRING = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SamplingParams:
|
||||
class SglSamplingParams:
|
||||
max_new_tokens: int = 16
|
||||
stop: Union[str, List[str]] = ()
|
||||
temperature: float = 1.0
|
||||
@@ -21,13 +21,14 @@ class SamplingParams:
|
||||
top_k: int = -1 # -1 means disable
|
||||
frequency_penalty: float = 0.0
|
||||
presence_penalty: float = 0.0
|
||||
ignore_eos: bool = False
|
||||
|
||||
# for constrained generation, not included in to_xxx_kwargs
|
||||
dtype: Optional[str] = None
|
||||
regex: Optional[str] = None
|
||||
|
||||
def clone(self):
|
||||
return SamplingParams(
|
||||
return SglSamplingParams(
|
||||
self.max_new_tokens,
|
||||
self.stop,
|
||||
self.temperature,
|
||||
@@ -67,6 +68,7 @@ class SamplingParams:
|
||||
"top_k": self.top_k,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"ignore_eos": self.ignore_eos,
|
||||
"regex": self.regex,
|
||||
}
|
||||
|
||||
@@ -98,13 +100,14 @@ class SglFunction:
|
||||
top_k: int = -1,
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
ignore_eos: bool = False,
|
||||
stream: bool = False,
|
||||
backend=None,
|
||||
**kwargs,
|
||||
):
|
||||
from sglang.lang.interpreter import run_program
|
||||
|
||||
default_sampling_para = SamplingParams(
|
||||
default_sampling_para = SglSamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop=stop,
|
||||
temperature=temperature,
|
||||
@@ -112,9 +115,9 @@ class SglFunction:
|
||||
top_k=top_k,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
ignore_eos=ignore_eos,
|
||||
)
|
||||
backend = backend or global_config.default_backend
|
||||
kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
|
||||
return run_program(self, backend, args, kwargs, default_sampling_para, stream)
|
||||
|
||||
def run_batch(
|
||||
@@ -128,6 +131,7 @@ class SglFunction:
|
||||
top_k: int = -1,
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
ignore_eos: bool = False,
|
||||
backend=None,
|
||||
num_threads: Union[str, int] = "auto",
|
||||
progress_bar: bool = False,
|
||||
@@ -139,7 +143,7 @@ class SglFunction:
|
||||
return []
|
||||
assert isinstance(batch_kwargs[0], dict)
|
||||
|
||||
default_sampling_para = SamplingParams(
|
||||
default_sampling_para = SglSamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop=stop,
|
||||
temperature=temperature,
|
||||
@@ -147,11 +151,9 @@ class SglFunction:
|
||||
top_k=top_k,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
ignore_eos=ignore_eos,
|
||||
)
|
||||
backend = backend or global_config.default_backend
|
||||
batch_kwargs = [
|
||||
{k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
|
||||
]
|
||||
return run_program_batch(
|
||||
self,
|
||||
backend,
|
||||
@@ -321,12 +323,13 @@ class SglGen(SglExpr):
|
||||
top_k,
|
||||
frequency_penalty,
|
||||
presence_penalty,
|
||||
ignore_eos,
|
||||
dtype,
|
||||
regex,
|
||||
):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.sampling_params = SamplingParams(
|
||||
self.sampling_params = SglSamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop=stop,
|
||||
temperature=temperature,
|
||||
@@ -334,6 +337,7 @@ class SglGen(SglExpr):
|
||||
top_k=top_k,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
ignore_eos=ignore_eos,
|
||||
dtype=dtype,
|
||||
regex=regex,
|
||||
)
|
||||
|
||||
@@ -40,7 +40,8 @@ def extract_prefix_by_tracing(program, backend):
|
||||
try:
|
||||
with TracingScope(tracer):
|
||||
tracer.ret_value = program.func(tracer, **arguments)
|
||||
except StopTracing:
|
||||
except (StopTracing, TypeError):
|
||||
# Some exceptions may not be catched
|
||||
pass
|
||||
|
||||
# Run and cache prefix
|
||||
|
||||
Reference in New Issue
Block a user