Higher priority for user input of max_prefill_tokens & format (#540)

Ying Sheng
2024-06-12 21:48:40 -07:00
committed by GitHub
parent 1374334d38
commit fb9296f0ed
50 changed files with 817 additions and 569 deletions
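Per the commit title, the functional half of this change gives a user-supplied max_prefill_tokens priority over the server's derived default; that code is not in the hunks shown below. A hypothetical sketch of the resolution pattern (server_args and derived_default are assumed names, not taken from the diff):

    # Hypothetical sketch only: prefer the user's explicit setting,
    # fall back to the automatically derived value otherwise.
    if server_args.max_prefill_tokens is not None:
        max_prefill_tokens = server_args.max_prefill_tokens  # user input wins
    else:
        max_prefill_tokens = derived_default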


@@ -1,10 +1,10 @@
 """Utilities for Huggingface Transformers."""
 
+import functools
 import json
 import os
 import warnings
-import functools
-from typing import Optional, Union, AbstractSet, Collection, Literal
+from typing import AbstractSet, Collection, Literal, Optional, Union
 
 from huggingface_hub import snapshot_download
 from transformers import (
@@ -179,6 +179,7 @@ def get_processor(
 class TiktokenTokenizer:
     def __init__(self, tokenizer_path):
         import tiktoken
+
         PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
 
         # Read JSON
@@ -190,7 +191,8 @@ class TiktokenTokenizer:
             bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
         }
         special_tokens = {
-            bytes(item["bytes"]).decode(): item["token"] for item in tok_dict["special_tokens"]
+            bytes(item["bytes"]).decode(): item["token"]
+            for item in tok_dict["special_tokens"]
         }
 
         assert tok_dict["word_split"] == "V1"
@@ -202,7 +204,10 @@ class TiktokenTokenizer:
         }
         if "default_allowed_special" in tok_dict:
             default_allowed_special = set(
-                [bytes(bytes_list).decode() for bytes_list in tok_dict["default_allowed_special"]]
+                [
+                    bytes(bytes_list).decode()
+                    for bytes_list in tok_dict["default_allowed_special"]
+                ]
             )
         else:
             default_allowed_special = None
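For reference, the tokenizer JSON shape implied by the parsing code above; the field names come from the diff, while the concrete values are invented for illustration:

    # Illustrative tok_dict only; values are made up.
    tok_dict = {
        "word_split": "V1",  # asserted by the loader
        "regular_tokens": [{"bytes": [104, 105], "token": 0}],  # b"hi" -> id 0
        "special_tokens": [{"bytes": list(b"<|eos|>"), "token": 1}],
        "default_allowed_special": [list(b"<|eos|>")],  # optional; decoded to {"<|eos|>"}
    }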
@@ -216,14 +221,20 @@ class TiktokenTokenizer:
             self,
             text: str,
             *,
-            allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
+            allowed_special: Union[
+                Literal["all"], AbstractSet[str]
+            ] = set(),  # noqa: B006
             disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         ) -> list[int]:
             if isinstance(allowed_special, set):
                 allowed_special |= self._default_allowed_special
             return tiktoken.Encoding.encode(
-                self, text, allowed_special=allowed_special, disallowed_special=disallowed_special
+                self,
+                text,
+                allowed_special=allowed_special,
+                disallowed_special=disallowed_special,
             )
+
         tokenizer.encode = functools.partial(encode_patched, tokenizer)
 
         # Convert to HF interface
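The patched encode above merges the tokenizer's default allowed specials into allowed_special before delegating to tiktoken.Encoding.encode. A self-contained illustration of why that matters, using tiktoken's bundled "gpt2" encoding instead of the JSON-loaded tokenizer:

    import tiktoken

    enc = tiktoken.get_encoding("gpt2")

    # With the defaults, special-token text in the input raises a ValueError.
    try:
        enc.encode("hello <|endoftext|>")
    except ValueError:
        pass

    # Listing the token in allowed_special encodes it as a single id;
    # the patched encode does this automatically for the defaults.
    ids = enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})
    assert ids == [enc.eot_token]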
@@ -237,10 +248,14 @@ class TiktokenTokenizer:
     def decode(self, x):
         return self.tokenizer.decode(x)
 
-    def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False):
+    def batch_decode(
+        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
+    ):
         if isinstance(batch[0], int):
             batch = [[x] for x in batch]
         return self.tokenizer.decode_batch(batch)
 
     def convert_ids_to_tokens(self, index):
-        return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
+        return self.tokenizer.decode_single_token_bytes(index).decode(
+            "utf-8", errors="ignore"
+        )
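A hypothetical usage sketch of the HF-style surface defined above (the tokenizer path is invented; behavior follows the methods in this hunk):

    tok = TiktokenTokenizer("/path/to/tokenizer.json")  # path is hypothetical
    texts = tok.batch_decode([[15496, 995], [50256]])  # list of id lists -> list of strings
    pieces = tok.batch_decode([15496, 995])  # flat ids are wrapped to [[15496], [995]] first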