[Minor] Many cleanup (#1357)
@@ -16,6 +16,7 @@ limitations under the License.
 """Cache for the compressed finite state machine."""
 
 from outlines.fsm.json_schema import build_regex_from_schema
+from transformers import AutoTokenizer
 
 from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_tool_cache import BaseToolCache
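
This hunk hoists the AutoTokenizer import to module level; as the last hunk shows, the old code imported it lazily behind an outlines version check that is now dropped. For reference, a minimal sketch of the removed gating pattern (the "0.0.35" threshold comes from the diff; the branch bodies here are illustrative):

    from importlib.metadata import version

    # The removed code compared version strings lexicographically, which
    # happens to work for these particular values; packaging.version.parse
    # is the robust way to do this comparison in general.
    if version("outlines") >= "0.0.35":
        from transformers import AutoTokenizer  # newer outlines wraps a HF tokenizer object
    else:
        AutoTokenizer = None  # older outlines accepted a tokenizer path directly
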
@@ -28,12 +29,9 @@ class FSMCache(BaseToolCache):
         tokenizer_args_dict,
         enable=True,
         skip_tokenizer_init=False,
-        json_schema_mode=False,
     ):
         super().__init__(enable=enable)
 
-        self.json_schema_mode = json_schema_mode
-
         if (
             skip_tokenizer_init
             or tokenizer_path.endswith(".json")
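
This hunk drops the json_schema_mode constructor flag: instead of building one cache per mode, the mode now travels in the cache key as a (key_type, key_string) tuple, handled by the new init_value in the next hunk. A hypothetical usage sketch under that reading (the schema and pattern literals are made up, and in sglang the lookup normally goes through the BaseToolCache machinery rather than calling init_value directly):

    cache = FSMCache(tokenizer_path, tokenizer_args_dict)

    # JSON-schema-constrained decoding: the schema string is compiled to a regex first.
    json_guide, json_regex = cache.init_value(("json", '{"type": "object"}'))

    # Regex-constrained decoding: the pattern string is used as-is.
    regex_guide, _ = cache.init_value(("regex", r"(yes|no)"))
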
@@ -42,44 +40,37 @@ class FSMCache(BaseToolCache):
             # Do not support TiktokenTokenizer or SentencePieceTokenizer
             return
 
-        from importlib.metadata import version
-
-        if version("outlines") >= "0.0.35":
-            from transformers import AutoTokenizer
-
-            tokenizer_args_dict.setdefault("padding_side", "left")
-            tokenizer = AutoTokenizer.from_pretrained(
-                tokenizer_path, **tokenizer_args_dict
-            )
-            try:
-                self.outlines_tokenizer = TransformerTokenizer(tokenizer)
-            except AttributeError:
-                # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
-                origin_pad_token_id = tokenizer.pad_token_id
-
-                def fset(self, value):
-                    self._value = value
-
-                type(tokenizer).pad_token_id = property(
-                    fget=type(tokenizer).pad_token_id.fget, fset=fset
-                )
-                self.outlines_tokenizer = TransformerTokenizer(tokenizer)
-                self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
-                self.outlines_tokenizer.pad_token_id = origin_pad_token_id
-                self.outlines_tokenizer.pad_token = (
-                    self.outlines_tokenizer.tokenizer.pad_token
-                )
-                self.outlines_tokenizer.vocabulary = (
-                    self.outlines_tokenizer.tokenizer.get_vocab()
-                )
-        else:
-            self.outlines_tokenizer = TransformerTokenizer(
-                tokenizer_path, **tokenizer_args_dict
-            )
+        tokenizer_args_dict.setdefault("padding_side", "left")
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
+        try:
+            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+        except AttributeError:
+            # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
+            origin_pad_token_id = tokenizer.pad_token_id
+
+            def fset(self, value):
+                self._value = value
+
+            type(tokenizer).pad_token_id = property(
+                fget=type(tokenizer).pad_token_id.fget, fset=fset
+            )
+            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+            self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
+            self.outlines_tokenizer.pad_token_id = origin_pad_token_id
+            self.outlines_tokenizer.pad_token = (
+                self.outlines_tokenizer.tokenizer.pad_token
+            )
+            self.outlines_tokenizer.vocabulary = (
+                self.outlines_tokenizer.tokenizer.get_vocab()
+            )
 
-    def init_value(self, value):
-        if self.json_schema_mode:
-            regex = build_regex_from_schema(value, whitespace_pattern=r"[\n\t ]*")
-            return RegexGuide(regex, self.outlines_tokenizer), regex
-        else:
-            return RegexGuide(value, self.outlines_tokenizer)
+    def init_value(self, key):
+        key_type, key_string = key
+        if key_type == "json":
+            regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*")
+        elif key_type == "regex":
+            regex = key_string
+        else:
+            raise ValueError(f"Invalid key_type: {key_type}")
+
+        return RegexGuide(regex, self.outlines_tokenizer), regex
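
The except AttributeError branch retained above works around tokenizers (chatglm2 and chatglm3, per the FIXME) whose pad_token_id is exposed as a getter-only property, so the assignment outlines' TransformerTokenizer performs raises. The fix rebuilds the property on the class with the original getter plus a setter. A self-contained sketch of that technique (the Demo class is hypothetical; note that the real tokenizer's getter does not read _value, which is why the diff also restores origin_pad_token_id on the wrapper afterwards):

    class Demo:
        @property
        def pad_token_id(self):
            # Stand-in getter; reads the slot the patched setter writes.
            return getattr(self, "_value", 0)

    obj = Demo()

    def fset(self, value):
        self._value = value

    # Rebuild the property: keep the original fget, add a setter, exactly
    # as the diff does with type(tokenizer).pad_token_id.
    type(obj).pad_token_id = property(fget=type(obj).pad_token_id.fget, fset=fset)

    obj.pad_token_id = 42  # would raise AttributeError before the patch
    assert obj.pad_token_id == 42
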