"""Utilities for Huggingface Transformers.""" import functools import json import os import warnings from typing import AbstractSet, Collection, Literal, Optional, Union from huggingface_hub import snapshot_download from transformers import ( AutoConfig, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, ) from sglang.srt.utils import is_multimodal_model def download_from_hf(model_path: str): if os.path.exists(model_path): return model_path return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"]) def get_config_json(model_path: str): with open(os.path.join(model_path, "config.json")) as f: config = json.load(f) return config def get_config( model: str, trust_remote_code: bool, revision: Optional[str] = None, model_overide_args: Optional[dict] = None, ): config = AutoConfig.from_pretrained( model, trust_remote_code=trust_remote_code, revision=revision ) if model_overide_args: config.update(model_overide_args) return config # Models don't use the same configuration key for determining the maximum # context length. Store them here so we can sanely check them. # NOTE: The ordering here is important. Some models have two of these and we # have a preference for which value gets used. CONTEXT_LENGTH_KEYS = [ "max_sequence_length", "seq_length", "max_position_embeddings", "max_seq_len", "model_max_length", ] def get_context_length(config): """Get the context length of a model from a huggingface model config.""" rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling: rope_scaling_factor = config.rope_scaling["factor"] else: rope_scaling_factor = 1 for key in CONTEXT_LENGTH_KEYS: val = getattr(config, key, None) if val is not None: return int(rope_scaling_factor * val) return 2048 # A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" def get_tokenizer( tokenizer_name: str, *args, tokenizer_mode: str = "auto", trust_remote_code: bool = False, tokenizer_revision: Optional[str] = None, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: if tokenizer_name.endswith(".json"): return TiktokenTokenizer(tokenizer_name) if tokenizer_name.endswith(".model"): return SentencePieceTokenizer(tokenizer_name) """Gets a tokenizer for the given model name via Huggingface.""" if is_multimodal_model(tokenizer_name): processor = get_processor( tokenizer_name, *args, trust_remote_code=trust_remote_code, tokenizer_revision=tokenizer_revision, **kwargs, ) tokenizer = processor.tokenizer return tokenizer if tokenizer_mode == "slow": if kwargs.get("use_fast", False): raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False if ( "llama" in tokenizer_name.lower() and kwargs.get("use_fast", True) and tokenizer_name != _FAST_LLAMA_TOKENIZER ): pass # warnings.warn( # "For some LLaMA V1 models, initializing the fast tokenizer may " # "take a long time. To reduce the initialization time, consider " # f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " # "tokenizer." # ) try: tokenizer = AutoTokenizer.from_pretrained( tokenizer_name, *args, trust_remote_code=trust_remote_code, tokenizer_revision=tokenizer_revision, **kwargs, ) except TypeError as e: # The LLaMA tokenizer causes a protobuf error in some environments. err_msg = ( "Failed to load the tokenizer. If you are using a LLaMA V1 model " f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the " "original tokenizer." 
# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


def get_tokenizer(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    tokenizer_revision: Optional[str] = None,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Gets a tokenizer for the given model name via Huggingface."""
    if tokenizer_name.endswith(".json"):
        return TiktokenTokenizer(tokenizer_name)

    if tokenizer_name.endswith(".model"):
        return SentencePieceTokenizer(tokenizer_name)

    if is_multimodal_model(tokenizer_name):
        processor = get_processor(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            tokenizer_revision=tokenizer_revision,
            **kwargs,
        )
        tokenizer = processor.tokenizer
        return tokenizer

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    if (
        "llama" in tokenizer_name.lower()
        and kwargs.get("use_fast", True)
        and tokenizer_name != _FAST_LLAMA_TOKENIZER
    ):
        pass
        # warnings.warn(
        #     "For some LLaMA V1 models, initializing the fast tokenizer may "
        #     "take a long time. To reduce the initialization time, consider "
        #     f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
        #     "tokenizer."
        # )
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            tokenizer_revision=tokenizer_revision,
            **kwargs,
        )
    except TypeError as e:
        # The LLaMA tokenizer causes a protobuf error in some environments.
        err_msg = (
            "Failed to load the tokenizer. If you are using a LLaMA V1 model "
            f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the "
            "original tokenizer."
        )
        raise RuntimeError(err_msg) from e
    except ValueError as e:
        # If the error pertains to the tokenizer class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        if not trust_remote_code and (
            "does not exist or is not currently imported." in str(e)
            or "requires you to execute the tokenizer file" in str(e)
        ):
            err_msg = (
                "Failed to load the tokenizer. If the tokenizer is a custom "
                "tokenizer not yet available in the HuggingFace transformers "
                "library, consider setting `trust_remote_code=True` in LLM "
                "or using the `--trust-remote-code` flag in the CLI."
            )
            raise RuntimeError(err_msg) from e
        else:
            raise e

    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        warnings.warn(
            "Using a slow tokenizer. This might cause a significant "
            "slowdown. Consider using a fast tokenizer instead."
        )
    return tokenizer


def get_processor(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    tokenizer_revision: Optional[str] = None,
    **kwargs,
):
    processor = AutoProcessor.from_pretrained(
        tokenizer_name,
        *args,
        trust_remote_code=trust_remote_code,
        tokenizer_revision=tokenizer_revision,
        **kwargs,
    )
    return processor
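
# A minimal usage sketch for `get_tokenizer` above. The file names and the
# model name are illustrative assumptions, not fixtures shipped with this
# module:
#
#   get_tokenizer("my-tokenizer.json")    # custom tiktoken JSON dump -> TiktokenTokenizer
#   get_tokenizer("my-tokenizer.model")   # SentencePiece model -> SentencePieceTokenizer
#   get_tokenizer("meta-llama/Llama-2-7b-hf")     # HF hub -> fast AutoTokenizer
#   get_tokenizer("gpt2", tokenizer_mode="slow")  # forces use_fast=False
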
class TiktokenTokenizer:
    def __init__(self, tokenizer_path):
        import tiktoken
        from jinja2 import Template

        PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""

        # Read JSON
        name = "tmp-json"
        with open(tokenizer_path, "rb") as fin:
            tok_dict = json.load(fin)

        mergeable_ranks = {
            bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
        }
        special_tokens = {
            bytes(item["bytes"]).decode(): item["token"]
            for item in tok_dict["special_tokens"]
        }
        assert tok_dict["word_split"] == "V1"

        kwargs = {
            "name": name,
            "pat_str": tok_dict.get("pat_str", PAT_STR_B),
            "mergeable_ranks": mergeable_ranks,
            "special_tokens": special_tokens,
        }
        if "default_allowed_special" in tok_dict:
            default_allowed_special = set(
                [
                    bytes(bytes_list).decode()
                    for bytes_list in tok_dict["default_allowed_special"]
                ]
            )
        else:
            default_allowed_special = None
        if "vocab_size" in tok_dict:
            kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]

        tokenizer = tiktoken.Encoding(**kwargs)
        tokenizer._default_allowed_special = default_allowed_special or set()
        tokenizer._default_allowed_special |= {"<|separator|>"}

        def encode_patched(
            self,
            text: str,
            *,
            allowed_special: Union[
                Literal["all"], AbstractSet[str]
            ] = set(),  # noqa: B006
            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        ) -> list[int]:
            if isinstance(allowed_special, set):
                # Build a new set instead of mutating `allowed_special` in
                # place, which would corrupt the shared mutable default (and
                # any set passed in by the caller).
                allowed_special = allowed_special | self._default_allowed_special
            return tiktoken.Encoding.encode(
                self,
                text,
                allowed_special=allowed_special,
                disallowed_special=disallowed_special,
            )

        tokenizer.encode = functools.partial(encode_patched, tokenizer)

        # Convert to HF interface
        self.tokenizer = tokenizer
        self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
        self.vocab_size = tokenizer.n_vocab
        self.chat_template = Template(
            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
        )

    def encode(self, x, add_special_tokens=False):
        return self.tokenizer.encode(x)

    def decode(self, x):
        return self.tokenizer.decode(x)

    def batch_decode(
        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
    ):
        if isinstance(batch[0], int):
            batch = [[x] for x in batch]
        return self.tokenizer.decode_batch(batch)

    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
        ret = self.chat_template.render(
            messages=messages, add_generation_prompt=add_generation_prompt
        )
        return self.encode(ret) if tokenize else ret


class SentencePieceTokenizer:
    def __init__(self, tokenizer_path):
        import sentencepiece as spm
        from jinja2 import Template

        tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)

        # Convert to HF interface
        self.tokenizer = tokenizer
        self.eos_token_id = tokenizer.eos_id()
        self.vocab_size = tokenizer.vocab_size()
        self.chat_template = Template(
            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
        )

    def encode(self, x, add_special_tokens=False):
        return self.tokenizer.encode(x)

    def decode(self, x):
        return self.tokenizer.decode(x)

    def batch_decode(
        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
    ):
        if isinstance(batch[0], int):
            batch = [[x] for x in batch]
        return self.tokenizer.decode(batch)

    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
        ret = self.chat_template.render(
            messages=messages, add_generation_prompt=add_generation_prompt
        )
        return self.encode(ret) if tokenize else ret
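

if __name__ == "__main__":
    # A hedged smoke test, runnable without tokenizer files or network
    # access. The config stand-in and the sample messages are illustrative
    # assumptions, not fixtures shipped with sglang.
    from types import SimpleNamespace

    from jinja2 import Template

    # get_context_length: the rope scaling factor multiplies the base length.
    cfg = SimpleNamespace(max_position_embeddings=4096, rope_scaling={"factor": 2.0})
    assert get_context_length(cfg) == 8192

    # The chat template shared by both wrapper classes above, rendered
    # directly through jinja2 (same template string as in the classes).
    chat_template = Template(
        "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
    )
    out = chat_template.render(
        messages=[
            {"role": "system", "content": "Be brief."},
            {"role": "user", "content": "Hi"},
        ],
        add_generation_prompt=True,
    )
    assert out == (
        "System: Be brief.<|separator|>\n\nHuman: Hi<|separator|>\n\nAssistant:"
    )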