Json Decode && Mutl-Turns (#4)

This commit is contained in:
Liangsheng Yin
2024-01-15 16:49:29 +08:00
committed by GitHub
parent f652494df1
commit 08ab2a1655
27 changed files with 755 additions and 41 deletions

View File

@@ -6,7 +6,7 @@ from typing import List, Union
from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
from sglang.lang.ir import (
SamplingParams,
SglSamplingParams,
SglArgument,
SglConstantText,
SglExpr,
@@ -140,7 +140,7 @@ class CompiledFunction:
kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
kwargs.update(self.function.bind_arguments)
default_sampling_para = SamplingParams(
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
@@ -173,7 +173,7 @@ class CompiledFunction:
backend = backend or global_config.default_backend
default_sampling_para = SamplingParams(
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,

View File

@@ -292,7 +292,7 @@ class StreamExecutor:
assert isinstance(other, SglExpr), f"{other}"
if isinstance(other, (SglConstantText, SglArgument)):
if isinstance(other, SglConstantText):
self._execute_fill(other.value)
elif isinstance(other, SglGen):
self._execute_gen(other)
@@ -332,8 +332,6 @@ class StreamExecutor:
def _execute_image(self, expr: SglImage):
path = expr.path
if isinstance(path, SglArgument):
path = path.value
base64_data = encode_image_base64(path)
@@ -419,7 +417,7 @@ class StreamExecutor:
"role": expr.role,
"content": [{"type": "text", "text": new_text}],
}
for (image_path, image_base64_data) in self.cur_images:
for image_path, image_base64_data in self.cur_images:
last_msg["content"].append(
{
"type": "image_url",
@@ -480,6 +478,7 @@ class StreamExecutor:
"top_k",
"frequency_penalty",
"presence_penalty",
"ignore_eos",
"dtype",
"regex",
]:

View File

@@ -13,7 +13,7 @@ REGEX_STRING = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
@dataclasses.dataclass
class SamplingParams:
class SglSamplingParams:
max_new_tokens: int = 16
stop: Union[str, List[str]] = ()
temperature: float = 1.0
@@ -21,13 +21,14 @@ class SamplingParams:
top_k: int = -1 # -1 means disable
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
ignore_eos: bool = False
# for constrained generation, not included in to_xxx_kwargs
dtype: Optional[str] = None
regex: Optional[str] = None
def clone(self):
return SamplingParams(
return SglSamplingParams(
self.max_new_tokens,
self.stop,
self.temperature,
@@ -67,6 +68,7 @@ class SamplingParams:
"top_k": self.top_k,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"ignore_eos": self.ignore_eos,
"regex": self.regex,
}
@@ -98,13 +100,14 @@ class SglFunction:
top_k: int = -1,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
ignore_eos: bool = False,
stream: bool = False,
backend=None,
**kwargs,
):
from sglang.lang.interpreter import run_program
default_sampling_para = SamplingParams(
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
@@ -112,9 +115,9 @@ class SglFunction:
top_k=top_k,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
)
backend = backend or global_config.default_backend
kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
return run_program(self, backend, args, kwargs, default_sampling_para, stream)
def run_batch(
@@ -128,6 +131,7 @@ class SglFunction:
top_k: int = -1,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
ignore_eos: bool = False,
backend=None,
num_threads: Union[str, int] = "auto",
progress_bar: bool = False,
@@ -139,7 +143,7 @@ class SglFunction:
return []
assert isinstance(batch_kwargs[0], dict)
default_sampling_para = SamplingParams(
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
@@ -147,11 +151,9 @@ class SglFunction:
top_k=top_k,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
)
backend = backend or global_config.default_backend
batch_kwargs = [
{k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
]
return run_program_batch(
self,
backend,
@@ -321,12 +323,13 @@ class SglGen(SglExpr):
top_k,
frequency_penalty,
presence_penalty,
ignore_eos,
dtype,
regex,
):
super().__init__()
self.name = name
self.sampling_params = SamplingParams(
self.sampling_params = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
@@ -334,6 +337,7 @@ class SglGen(SglExpr):
top_k=top_k,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
dtype=dtype,
regex=regex,
)

View File

@@ -40,7 +40,8 @@ def extract_prefix_by_tracing(program, backend):
try:
with TracingScope(tracer):
tracer.ret_value = program.func(tracer, **arguments)
except StopTracing:
except (StopTracing, TypeError):
# Some exceptions may not be catched
pass
# Run and cache prefix