Improve logging & add logit cap (#471)
@@ -84,6 +84,9 @@ def get_tokenizer(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    if tokenizer_name.endswith(".json"):
+        return TiktokenTokenizer(tokenizer_name)
+
     """Gets a tokenizer for the given model name via Huggingface."""
     if is_multimodal_model(tokenizer_name):
         processor = get_processor(
@@ -170,3 +173,24 @@ def get_processor(
         **kwargs,
     )
     return processor
+
+
+class TiktokenTokenizer:
+    def __init__(self, tokenizer_path):
+        import xlm.tokenizers.tiktoken_wrapper as tiktoken_wrapper
+        tokenizer = tiktoken_wrapper.Encoding.from_xtok_json("xtok-json", tokenizer_path)
+        self.tokenizer = tokenizer
+        self.eos_token_id = tokenizer.eos_token
+        self.vocab_size = tokenizer.n_vocab
+
+    def encode(self, x):
+        return self.tokenizer.encode(x)
+
+    def decode(self, x):
+        return self.tokenizer.decode(x)
+
+    def batch_decode(self, batch, skip_special_tokens, spaces_between_special_tokens):
+        return self.tokenizer.decode_batch(batch)
+
+    def convert_ids_to_tokens(self, index):
+        return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
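
For context, a minimal sketch of how the new code path is exercised. This is not part of the commit: the ./tokenizer.json path is hypothetical, and it assumes the repo's xlm.tokenizers.tiktoken_wrapper module is importable so that get_tokenizer can construct a TiktokenTokenizer.

    # Minimal usage sketch (not from the commit). Assumes the repo is on
    # PYTHONPATH and ./tokenizer.json is a valid xtok-format tokenizer file
    # (hypothetical path); get_tokenizer dispatches on the ".json" suffix and
    # bypasses Huggingface entirely for such paths.
    tokenizer = get_tokenizer("./tokenizer.json")  # -> TiktokenTokenizer

    ids = tokenizer.encode("hello world")          # token ids via tiktoken
    assert tokenizer.decode(ids) == "hello world"  # single-sequence round trip

    # batch_decode accepts the HF-style flags only for interface compatibility;
    # they are ignored and the call defers to tiktoken's decode_batch.
    texts = tokenizer.batch_decode(
        [ids, ids], skip_special_tokens=True, spaces_between_special_tokens=False
    )

Duck-typing this handful of members (encode, decode, batch_decode, convert_ids_to_tokens, plus eos_token_id and vocab_size) is what lets the tiktoken-backed tokenizer stand in for a Huggingface tokenizer elsewhere in the codebase, even though the annotated return type of get_tokenizer does not mention it.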