Use dtype to control generate (#1082)

Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
Liangsheng Yin
2024-08-14 08:58:07 -07:00
committed by GitHub
parent 67c0d832a6
commit a34dd86a7d
12 changed files with 110 additions and 88 deletions

View File

@@ -1,21 +1,23 @@
import json
import warnings
from typing import List, Optional
from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.choices import (
ChoicesDecision,
ChoicesSamplingMethod,
token_length_normalized,
)
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
from sglang.lang.ir import (
REGEX_BOOL,
REGEX_FLOAT,
REGEX_INT,
REGEX_STR,
SglSamplingParams,
)
from sglang.utils import http_request
class RuntimeEndpoint(BaseBackend):
def __init__(
self,
base_url: str,
@@ -95,32 +97,52 @@ class RuntimeEndpoint(BaseBackend):
)
self._assert_success(res)
def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams):
if sampling_params.dtype is None:
return
if sampling_params.stop == ():
sampling_params.stop = []
dtype_regex = None
if sampling_params.dtype in ["int", int]:
dtype_regex = REGEX_INT
sampling_params.stop.extend([" ", "\n"])
elif sampling_params.dtype in ["float", float]:
dtype_regex = REGEX_FLOAT
sampling_params.stop.extend([" ", "\n"])
elif sampling_params.dtype in ["str", str]:
dtype_regex = REGEX_STR
elif sampling_params.dtype in ["bool", bool]:
dtype_regex = REGEX_BOOL
else:
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
if dtype_regex is not None and sampling_params.regex is not None:
warnings.warn(
f"Both dtype and regex are set. Only dtype will be used. dtype: {sampling_params.dtype}, regex: {sampling_params.regex}"
)
sampling_params.regex = dtype_regex
def generate(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
):
if sampling_params.dtype is None:
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
**sampling_params.to_srt_kwargs(),
},
}
elif sampling_params.dtype in [int, "int"]:
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
"dtype": "int",
**sampling_params.to_srt_kwargs(),
},
}
else:
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
self._handle_dtype_to_regex(sampling_params)
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
**sampling_params.to_srt_kwargs(),
},
}
for item in [
"return_logprob",
@@ -151,27 +173,16 @@ class RuntimeEndpoint(BaseBackend):
s: StreamExecutor,
sampling_params: SglSamplingParams,
):
if sampling_params.dtype is None:
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
**sampling_params.to_srt_kwargs(),
},
}
elif sampling_params.dtype in [int, "int"]:
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
"dtype": "int",
**sampling_params.to_srt_kwargs(),
},
}
else:
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
self._handle_dtype_to_regex(sampling_params)
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
**sampling_params.to_srt_kwargs(),
},
}
for item in [
"return_logprob",

View File

@@ -8,10 +8,10 @@ from typing import List, Optional, Union
from sglang.global_config import global_config
from sglang.lang.choices import ChoicesSamplingMethod
REGEX_INT = r"[-+]?[0-9]+"
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+"
REGEX_INT = r"[-+]?[0-9]+[ \n]*"
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+[ \n]*"
REGEX_BOOL = r"(True|False)"
REGEX_STRING = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
@dataclasses.dataclass