27 lines
568 B
Python
27 lines
568 B
Python
import json, os, sys
|
|
|
|
# Pre-flight check of a model's attention head size.
#
# Exit status when run as a script:
#   0 -- the head size is in SUPPORTED, or it cannot be determined (no
#        config.json, or neither a usable "head_dim" nor an evenly divisible
#        "hidden_size" / "num_attention_heads" pair).
#   2 -- the head size is not in SUPPORTED; the offending value is printed
#        to stdout.
# A malformed config.json is not caught: json.load raises and the
# interpreter exits with a traceback.

# Directory containing the model's config.json; override with $MODEL_DIR.
MODEL_DIR = os.environ.get("MODEL_DIR", "/model")

# Head sizes this check accepts. NOTE(review): presumably the sizes supported
# by the downstream attention implementation -- confirm before editing.
SUPPORTED = frozenset({32, 64, 96, 128, 160, 192, 224, 256})


def head_size_from_config(cfg):
    """Return the attention head size described by *cfg*, or ``None``.

    An explicit, non-null ``"head_dim"`` wins. Otherwise the size is derived
    as ``hidden_size // num_attention_heads`` when both are present, truthy,
    and divide evenly.

    Args:
        cfg: Parsed contents of the model's ``config.json``.

    Returns:
        The head size, or ``None`` when it cannot be determined.
    """
    head_size = cfg.get("head_dim")
    if head_size is not None:
        return head_size

    hidden = cfg.get("hidden_size")
    heads = cfg.get("num_attention_heads")
    # Truthiness check also rules out heads == 0 (ZeroDivisionError).
    if not (hidden and heads):
        return None
    # A non-integral ratio has no meaningful head size; flooring it would
    # fabricate one (e.g. 129 // 4 == 32, which would wrongly pass).
    if hidden % heads:
        return None
    return hidden // heads


def main(model_dir=None):
    """Check the model's head size and return the process exit status.

    Args:
        model_dir: Directory holding ``config.json``; defaults to MODEL_DIR.

    Returns:
        0 when the head size is supported or undeterminable, 2 (after
        printing the head size to stdout) when it is unsupported.
    """
    if model_dir is None:
        model_dir = MODEL_DIR

    cfg_path = os.path.join(model_dir, "config.json")
    if not os.path.exists(cfg_path):
        return 0

    # JSON is UTF-8 by specification; don't depend on the locale default.
    with open(cfg_path, encoding="utf-8") as f:
        cfg = json.load(f)

    head_size = head_size_from_config(cfg)
    if head_size is None:
        return 0

    if head_size not in SUPPORTED:
        print(head_size)
        return 2
    return 0


if __name__ == "__main__":
    sys.exit(main())