Use dtype to control generate (#1082)
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
@@ -1,21 +1,23 @@
|
||||
import json
|
||||
import warnings
|
||||
from typing import List, Optional
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||
from sglang.lang.choices import (
|
||||
ChoicesDecision,
|
||||
ChoicesSamplingMethod,
|
||||
token_length_normalized,
|
||||
)
|
||||
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
from sglang.lang.ir import (
|
||||
REGEX_BOOL,
|
||||
REGEX_FLOAT,
|
||||
REGEX_INT,
|
||||
REGEX_STR,
|
||||
SglSamplingParams,
|
||||
)
|
||||
from sglang.utils import http_request
|
||||
|
||||
|
||||
class RuntimeEndpoint(BaseBackend):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
@@ -95,32 +97,52 @@ class RuntimeEndpoint(BaseBackend):
|
||||
)
|
||||
self._assert_success(res)
|
||||
|
||||
def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams):
|
||||
if sampling_params.dtype is None:
|
||||
return
|
||||
|
||||
if sampling_params.stop == ():
|
||||
sampling_params.stop = []
|
||||
|
||||
dtype_regex = None
|
||||
if sampling_params.dtype in ["int", int]:
|
||||
|
||||
dtype_regex = REGEX_INT
|
||||
sampling_params.stop.extend([" ", "\n"])
|
||||
elif sampling_params.dtype in ["float", float]:
|
||||
|
||||
dtype_regex = REGEX_FLOAT
|
||||
sampling_params.stop.extend([" ", "\n"])
|
||||
elif sampling_params.dtype in ["str", str]:
|
||||
|
||||
dtype_regex = REGEX_STR
|
||||
elif sampling_params.dtype in ["bool", bool]:
|
||||
|
||||
dtype_regex = REGEX_BOOL
|
||||
else:
|
||||
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
|
||||
|
||||
if dtype_regex is not None and sampling_params.regex is not None:
|
||||
warnings.warn(
|
||||
f"Both dtype and regex are set. Only dtype will be used. dtype: {sampling_params.dtype}, regex: {sampling_params.regex}"
|
||||
)
|
||||
|
||||
sampling_params.regex = dtype_regex
|
||||
|
||||
def generate(
|
||||
self,
|
||||
s: StreamExecutor,
|
||||
sampling_params: SglSamplingParams,
|
||||
):
|
||||
if sampling_params.dtype is None:
|
||||
data = {
|
||||
"text": s.text_,
|
||||
"sampling_params": {
|
||||
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
||||
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
|
||||
**sampling_params.to_srt_kwargs(),
|
||||
},
|
||||
}
|
||||
elif sampling_params.dtype in [int, "int"]:
|
||||
data = {
|
||||
"text": s.text_,
|
||||
"sampling_params": {
|
||||
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
||||
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
|
||||
"dtype": "int",
|
||||
**sampling_params.to_srt_kwargs(),
|
||||
},
|
||||
}
|
||||
else:
|
||||
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
|
||||
self._handle_dtype_to_regex(sampling_params)
|
||||
data = {
|
||||
"text": s.text_,
|
||||
"sampling_params": {
|
||||
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
||||
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
|
||||
**sampling_params.to_srt_kwargs(),
|
||||
},
|
||||
}
|
||||
|
||||
for item in [
|
||||
"return_logprob",
|
||||
@@ -151,27 +173,16 @@ class RuntimeEndpoint(BaseBackend):
|
||||
s: StreamExecutor,
|
||||
sampling_params: SglSamplingParams,
|
||||
):
|
||||
if sampling_params.dtype is None:
|
||||
data = {
|
||||
"text": s.text_,
|
||||
"sampling_params": {
|
||||
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
||||
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
|
||||
**sampling_params.to_srt_kwargs(),
|
||||
},
|
||||
}
|
||||
elif sampling_params.dtype in [int, "int"]:
|
||||
data = {
|
||||
"text": s.text_,
|
||||
"sampling_params": {
|
||||
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
||||
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
|
||||
"dtype": "int",
|
||||
**sampling_params.to_srt_kwargs(),
|
||||
},
|
||||
}
|
||||
else:
|
||||
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
|
||||
self._handle_dtype_to_regex(sampling_params)
|
||||
|
||||
data = {
|
||||
"text": s.text_,
|
||||
"sampling_params": {
|
||||
"skip_special_tokens": global_config.skip_special_tokens_in_output,
|
||||
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
|
||||
**sampling_params.to_srt_kwargs(),
|
||||
},
|
||||
}
|
||||
|
||||
for item in [
|
||||
"return_logprob",
|
||||
|
||||
@@ -8,10 +8,10 @@ from typing import List, Optional, Union
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.choices import ChoicesSamplingMethod
|
||||
|
||||
REGEX_INT = r"[-+]?[0-9]+"
|
||||
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+"
|
||||
REGEX_INT = r"[-+]?[0-9]+[ \n]*"
|
||||
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+[ \n]*"
|
||||
REGEX_BOOL = r"(True|False)"
|
||||
REGEX_STRING = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
|
||||
REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
||||
Reference in New Issue
Block a user