[Minor] improve code style
This commit is contained in:
@@ -3,7 +3,8 @@
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import functools
|
||||
from typing import Optional, Union, AbstractSet, Collection, Literal
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import (
|
||||
@@ -177,10 +178,57 @@ def get_processor(
|
||||
|
||||
class TiktokenTokenizer:
|
||||
def __init__(self, tokenizer_path):
    """Build a tiktoken encoding from a JSON vocabulary file.

    The JSON document at ``tokenizer_path`` is read for these keys:

    * ``regular_tokens`` / ``special_tokens``: lists of items that each
      carry ``bytes`` (a list of byte values) and ``token`` (the id).
    * ``word_split``: must equal ``"V1"``.
    * ``pat_str`` (optional): split regex; ``PAT_STR_B`` is used otherwise.
    * ``default_allowed_special`` (optional): byte lists naming special
      tokens that ``encode`` should always permit.
    * ``vocab_size`` (optional): forwarded as ``explicit_n_vocab``.

    After construction the instance exposes ``tokenizer``,
    ``eos_token_id`` (the id of ``"<|eos|>"``) and ``vocab_size``.

    Raises:
        AssertionError: if ``word_split`` is not ``"V1"``.
        KeyError: if a required key or the ``"<|eos|>"`` special token
            is missing.
    """
    # Imported lazily so the module can be loaded without tiktoken.
    import tiktoken

    PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""

    # Read JSON
    name = "tmp-json"
    with open(tokenizer_path, "rb") as fin:
        tok_dict = json.load(fin)

    mergeable_ranks = {
        bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
    }
    special_tokens = {
        bytes(item["bytes"]).decode(): item["token"]
        for item in tok_dict["special_tokens"]
    }
    assert tok_dict["word_split"] == "V1"

    kwargs = {
        "name": name,
        "pat_str": tok_dict.get("pat_str", PAT_STR_B),
        "mergeable_ranks": mergeable_ranks,
        "special_tokens": special_tokens,
    }
    if "vocab_size" in tok_dict:
        kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]

    # A missing "default_allowed_special" key yields an empty set.
    default_allowed_special = {
        bytes(bytes_list).decode()
        for bytes_list in tok_dict.get("default_allowed_special", [])
    }

    tokenizer = tiktoken.Encoding(**kwargs)
    tokenizer._default_allowed_special = default_allowed_special

    def encode_patched(
        self,
        text: str,
        *,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
    ) -> list[int]:
        """Encode ``text``, always permitting the default special tokens."""
        if isinstance(allowed_special, set):
            # Build a new set instead of using ``|=``: the in-place form
            # would mutate the shared default argument and the caller's set.
            allowed_special = allowed_special | self._default_allowed_special
        return tiktoken.Encoding.encode(
            self,
            text,
            allowed_special=allowed_special,
            disallowed_special=disallowed_special,
        )

    # Shadow the bound method on this one instance only.
    tokenizer.encode = functools.partial(encode_patched, tokenizer)

    # Convert to HF interface
    self.tokenizer = tokenizer
    self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
    self.vocab_size = tokenizer.n_vocab
|
||||
|
||||
def encode(self, x, add_special_tokens=False):
|
||||
@@ -190,6 +238,8 @@ class TiktokenTokenizer:
|
||||
return self.tokenizer.decode(x)
|
||||
|
||||
def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False):
    """Decode a batch of token-id sequences via ``self.tokenizer.decode_batch``.

    ``batch`` may be a sequence of token-id sequences or a flat sequence
    of ints; in the flat case every id is decoded as its own one-token
    sequence. An empty batch is passed through unchanged.

    ``skip_special_tokens`` and ``spaces_between_special_tokens`` are
    accepted but not used by this method.
    """
    # Check the length first so an empty batch does not raise IndexError.
    if len(batch) > 0 and isinstance(batch[0], int):
        batch = [[x] for x in batch]
    return self.tokenizer.decode_batch(batch)
|
||||
|
||||
def convert_ids_to_tokens(self, index):
|
||||
|
||||
@@ -88,9 +88,9 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
|
||||
return pred
|
||||
|
||||
|
||||
def call_generate_xinfer(prompt, temperature, max_tokens, stop=None, url=None):
|
||||
def call_generate_ginfer(prompt, temperature, max_tokens, stop=None, url=None):
|
||||
import grpc
|
||||
from xlm.proto import sampler_pb2, sampler_pb2_grpc
|
||||
from ginfer import sampler_pb2, sampler_pb2_grpc
|
||||
|
||||
sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
|
||||
sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)
|
||||
@@ -255,7 +255,7 @@ def add_common_other_args_and_parse(parser):
|
||||
"vllm",
|
||||
"outlines",
|
||||
"lightllm",
|
||||
"xinfer",
|
||||
"ginfer",
|
||||
"guidance",
|
||||
"lmql",
|
||||
"srt-raw",
|
||||
@@ -276,7 +276,7 @@ def add_common_other_args_and_parse(parser):
|
||||
"lightllm": 22000,
|
||||
"lmql": 23000,
|
||||
"srt-raw": 30000,
|
||||
"xinfer": 9988,
|
||||
"ginfer": 9988,
|
||||
}
|
||||
args.port = default_port.get(args.backend, None)
|
||||
return args
|
||||
@@ -312,8 +312,8 @@ def _get_call_generate(args):
|
||||
return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate")
|
||||
elif args.backend == "srt-raw":
|
||||
return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate")
|
||||
elif args.backend == "xinfer":
|
||||
return partial(call_generate_xinfer, url=f"{args.host}:{args.port}")
|
||||
elif args.backend == "ginfer":
|
||||
return partial(call_generate_ginfer, url=f"{args.host}:{args.port}")
|
||||
elif args.backend == "outlines":
|
||||
return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate")
|
||||
elif args.backend == "guidance":
|
||||
|
||||
Reference in New Issue
Block a user