[Minor] improve code style

This commit is contained in:
Lianmin Zheng
2024-06-03 18:11:34 -07:00
parent 9f009261f2
commit 3bc01ac137
4 changed files with 67 additions and 17 deletions

View File

@@ -3,7 +3,8 @@
import json
import os
import warnings
from typing import List, Optional, Tuple, Union
import functools
from typing import Optional, Union, AbstractSet, Collection, Literal
from huggingface_hub import snapshot_download
from transformers import (
@@ -177,10 +178,57 @@ def get_processor(
class TiktokenTokenizer:
def __init__(self, tokenizer_path):
    """Build a tiktoken encoding from a JSON vocabulary file.

    Args:
        tokenizer_path: Path to a JSON file containing "regular_tokens",
            "special_tokens" and "word_split", and optionally "pat_str",
            "default_allowed_special" and "vocab_size".

    Raises:
        AssertionError: If the file's "word_split" entry is not "V1".
        KeyError: If a required entry (including the "<|eos|>" special
            token) is missing.
    """
    import tiktoken

    # Fallback split pattern, used only when the JSON has no "pat_str".
    PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""

    # Read JSON
    name = "tmp-json"
    with open(tokenizer_path, "rb") as fin:
        tok_dict = json.load(fin)

    # Each token's byte sequence is stored in the JSON as a list of ints.
    mergeable_ranks = {
        bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
    }
    special_tokens = {
        bytes(item["bytes"]).decode(): item["token"]
        for item in tok_dict["special_tokens"]
    }
    assert tok_dict["word_split"] == "V1"

    kwargs = {
        "name": name,
        "pat_str": tok_dict.get("pat_str", PAT_STR_B),
        "mergeable_ranks": mergeable_ranks,
        "special_tokens": special_tokens,
    }
    if "vocab_size" in tok_dict:
        kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]

    # Special tokens that encode() always allows; empty when the JSON does
    # not list any.
    default_allowed_special = {
        bytes(bytes_list).decode()
        for bytes_list in tok_dict.get("default_allowed_special", ())
    }

    tokenizer = tiktoken.Encoding(**kwargs)
    tokenizer._default_allowed_special = default_allowed_special

    def encode_patched(
        self,
        text: str,
        *,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = frozenset(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
    ) -> list[int]:
        """Encode `text`, always allowing the encoding's default special tokens."""
        if isinstance(allowed_special, (set, frozenset)):
            # Build a new set rather than using `|=`, which would mutate the
            # caller's set (or a shared default argument) in place.
            allowed_special = set(allowed_special) | self._default_allowed_special
        return tiktoken.Encoding.encode(
            self,
            text,
            allowed_special=allowed_special,
            disallowed_special=disallowed_special,
        )

    # Shadow the bound method on this instance with the patched version.
    tokenizer.encode = functools.partial(encode_patched, tokenizer)

    # Convert to HF interface
    self.tokenizer = tokenizer
    self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
    self.vocab_size = tokenizer.n_vocab
def encode(self, x, add_special_tokens=False):
@@ -190,6 +238,8 @@ class TiktokenTokenizer:
return self.tokenizer.decode(x)
def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False):
    """Decode a batch of token sequences with the wrapped tokenizer.

    Args:
        batch: Either a list of token-id lists, or a flat list of ints, in
            which case every id is decoded as its own one-token sequence.
        skip_special_tokens: Accepted but not used by this method.
        spaces_between_special_tokens: Accepted but not used by this method.

    Returns:
        Whatever `self.tokenizer.decode_batch` returns for the normalized batch.
    """
    # Check for emptiness first so an empty batch does not raise IndexError
    # on `batch[0]`.
    if batch and isinstance(batch[0], int):
        batch = [[token_id] for token_id in batch]
    return self.tokenizer.decode_batch(batch)
def convert_ids_to_tokens(self, index):

View File

@@ -88,9 +88,9 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
return pred
def call_generate_xinfer(prompt, temperature, max_tokens, stop=None, url=None):
def call_generate_ginfer(prompt, temperature, max_tokens, stop=None, url=None):
import grpc
from xlm.proto import sampler_pb2, sampler_pb2_grpc
from ginfer import sampler_pb2, sampler_pb2_grpc
sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)
@@ -255,7 +255,7 @@ def add_common_other_args_and_parse(parser):
"vllm",
"outlines",
"lightllm",
"xinfer",
"ginfer",
"guidance",
"lmql",
"srt-raw",
@@ -276,7 +276,7 @@ def add_common_other_args_and_parse(parser):
"lightllm": 22000,
"lmql": 23000,
"srt-raw": 30000,
"xinfer": 9988,
"ginfer": 9988,
}
args.port = default_port.get(args.backend, None)
return args
@@ -312,8 +312,8 @@ def _get_call_generate(args):
return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate")
elif args.backend == "srt-raw":
return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate")
elif args.backend == "xinfer":
return partial(call_generate_xinfer, url=f"{args.host}:{args.port}")
elif args.backend == "ginfer":
return partial(call_generate_ginfer, url=f"{args.host}:{args.port}")
elif args.backend == "outlines":
return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate")
elif args.backend == "guidance":