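# K100-vllm/K100-vLLM-Patched-v2.0/fix_tokenizer.py
#
# Copies the tokenizer files shipped with the model in MODEL_DIR into
# FIX_TOKENIZER_DIR and rewrites "tokenizer_class" in the copied
# tokenizer_config.json when the declared class is missing, not exported by
# the installed transformers package, or on a known-bad list.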

import os
import shutil
import json
import sys
import transformers
import inspect
sys.path.insert(0, '/opt')
from detect_tokenizer import detect
# Where the model's tokenizer files are read from, and where the patched copy
# is written; both can be overridden through environment variables.
MODEL_DIR = os.getenv("MODEL_DIR", "/model")
OUT_DIR = os.getenv("FIX_TOKENIZER_DIR", "/tmp/fixed_tokenizer")
os.makedirs(OUT_DIR, exist_ok=True)  # no-op when the directory already exists
def copy_if_exists(name):
    """Copy ``MODEL_DIR/<name>`` into ``OUT_DIR`` when it is a regular file.

    Args:
        name: File name relative to ``MODEL_DIR``.

    Returns:
        True if the file was copied; False if it does not exist or is not a
        regular file. Directories are skipped because ``shutil.copy`` cannot
        copy them. The boolean is new; existing callers ignore the return
        value.
    """
    src = os.path.join(MODEL_DIR, name)
    # isfile (unlike exists) rejects directories and follows symlinks, so
    # symlinked files (e.g. from a model cache) are still copied.
    if not os.path.isfile(src):
        return False
    shutil.copy(src, OUT_DIR)
    return True
# Tokenizer artifacts to mirror from MODEL_DIR; copy_if_exists() only copies
# the ones that are actually present, so absent files are skipped.
for tok_file in (
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "vocab.json",
    "merges.txt",
    "tokenizer.model",
):
    copy_if_exists(tok_file)
# detect() comes from the project-local detect_tokenizer module (presumably
# /opt/detect_tokenizer.py, given the sys.path insert above). It yields the
# tokenizer type (used as a FALLBACK key) and the declared tokenizer_class,
# which may be falsy when none is declared.
typ, orig_cls = detect(MODEL_DIR)

# Load the copied tokenizer_config.json (when the model shipped one) so it can
# be patched; start from an empty config otherwise. JSON files are UTF-8, so
# decode explicitly instead of relying on the locale's default encoding,
# which breaks on non-ASCII tokens or chat templates.
cfg_path = os.path.join(OUT_DIR, "tokenizer_config.json")
if os.path.exists(cfg_path):
    with open(cfg_path, encoding="utf-8") as f:
        cfg = json.load(f)
else:
    cfg = {}
# Class names that are always overridden, even when the installed
# transformers package exports them.
BAD_CLASSES = {"TokenizersBackend", "TiktokenTokenizer"}

# Replacement tokenizer_class keyed by the tokenizer type reported by
# detect(); unknown types use "PreTrainedTokenizerFast".
FALLBACK = {
    "fast": "PreTrainedTokenizerFast",
    "sentencepiece": "LlamaTokenizer",
    "bpe": "GPT2TokenizerFast",
}


def _is_valid_tokenizer_class(name):
    """Return True if *name* is a usable tokenizer class exported by transformers.

    Only the requested attribute is resolved. Enumerating every member of the
    package with ``inspect.getmembers`` resolves every lazily-exported name,
    which imports every submodule. That is slow, and it can raise as soon as
    any one submodule fails to import, because ``getmembers`` only tolerates
    ``AttributeError``.

    Args:
        name: Candidate ``tokenizer_class`` value taken from the model config.

    Returns:
        True when *name* is a string containing "Tokenizer", is not listed in
        ``BAD_CLASSES``, and ``transformers.<name>`` resolves to a class.
    """
    if not isinstance(name, str) or "Tokenizer" not in name:
        return False
    if name in BAD_CLASSES:
        return False
    try:
        obj = getattr(transformers, name, None)
    except (ImportError, RuntimeError):
        # The name is exported but its submodule failed to import; treat it
        # as unusable so the fallback class gets applied instead.
        return False
    return inspect.isclass(obj)


if orig_cls and _is_valid_tokenizer_class(orig_cls):
    print(f"[fix_tokenizer] tokenizer_class '{orig_cls}' is valid, skip override")
else:
    fallback = FALLBACK.get(typ, "PreTrainedTokenizerFast")
    if orig_cls:
        # Separator was missing: old and new class names were printed fused.
        print(f"[fix_tokenizer] override bad tokenizer_class: {orig_cls} → {fallback}")
    else:
        print(f"[fix_tokenizer] tokenizer_class missing, set to: {fallback}")
    cfg["tokenizer_class"] = fallback
# Persist the config (unchanged, or carrying the replacement tokenizer_class)
# alongside the copied tokenizer files, then report the output location.
payload = json.dumps(cfg)
with open(cfg_path, "w") as out_file:
    out_file.write(payload)
print(f"[fix_tokenizer] done → {OUT_DIR}")