diff --git a/K100-vLLM-Patched-v2.0/fix_tokenizer.py b/K100-vLLM-Patched-v2.0/fix_tokenizer.py
index 80d79f6..10fbba5 100644
--- a/K100-vLLM-Patched-v2.0/fix_tokenizer.py
+++ b/K100-vLLM-Patched-v2.0/fix_tokenizer.py
@@ -2,6 +2,8 @@ import os
 import shutil
 import json
 import sys
+import transformers
+import inspect
 
 sys.path.insert(0, '/opt')
 from detect_tokenizer import detect
@@ -35,22 +37,43 @@ if os.path.exists(cfg_path):
 else:
     cfg = {}
 
-if typ == "fast":
-    cfg["tokenizer_class"] = "PreTrainedTokenizerFast"
-elif typ == "sentencepiece":
-    cfg["tokenizer_class"] = "LlamaTokenizer"
-elif typ == "bpe":
-    cfg["tokenizer_class"] = "GPT2TokenizerFast"
+BAD_CLASSES = {"TokenizersBackend", "TiktokenTokenizer"}
+
+FALLBACK = {
+    "fast": "PreTrainedTokenizerFast",
+    "sentencepiece": "LlamaTokenizer",
+    "bpe": "GPT2TokenizerFast",
+}
+
+
+def _is_valid_tokenizer_class(name):
+    """Return True if *name* is a usable tokenizer class exported by transformers.
+
+    Only the one requested attribute is resolved. Enumerating the package with
+    inspect.getmembers() would call getattr() on every exported name, and
+    getmembers() only tolerates AttributeError from those lookups.
+    """
+    if not isinstance(name, str) or "Tokenizer" not in name:
+        return False
+    if name in BAD_CLASSES:
+        return False
+    try:
+        obj = getattr(transformers, name)
+    except (AttributeError, ImportError, RuntimeError):
+        # Unknown name, or its backing module could not be imported here.
+        return False
+    return inspect.isclass(obj)
+
+
+if _is_valid_tokenizer_class(orig_cls):
+    print(f"[fix_tokenizer] tokenizer_class '{orig_cls}' is valid, skip override")
 else:
-    cfg["tokenizer_class"] = "PreTrainedTokenizerFast"
-
-bad_classes = [
-    "TokenizersBackend",
-    "TiktokenTokenizer",
-]
-
-if orig_cls in bad_classes:
-    print(f"[fix] override bad tokenizer_class: {orig_cls} → {cfg['tokenizer_class']}")
+    fallback = FALLBACK.get(typ, "PreTrainedTokenizerFast")
+    if orig_cls:
+        print(f"[fix] override bad tokenizer_class: {orig_cls} → {fallback}")
+    else:
+        print(f"[fix] tokenizer_class missing, set to: {fallback}")
+    cfg["tokenizer_class"] = fallback
 
 with open(cfg_path, "w") as f:
     json.dump(cfg, f)