添加 K100-vLLM-Patched-v2.0/detect_head_size.py
This commit is contained in:
27
K100-vLLM-Patched-v2.0/detect_head_size.py
Normal file
27
K100-vLLM-Patched-v2.0/detect_head_size.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import json, os, sys
|
||||
|
||||
MODEL_DIR = os.environ.get("MODEL_DIR", "/model")
|
||||
cfg_path = os.path.join(MODEL_DIR, "config.json")
|
||||
|
||||
if not os.path.exists(cfg_path):
|
||||
sys.exit(0)
|
||||
|
||||
with open(cfg_path) as f:
|
||||
cfg = json.load(f)
|
||||
|
||||
head_size = cfg.get("head_dim")
|
||||
if head_size is None:
|
||||
hs = cfg.get("hidden_size")
|
||||
nh = cfg.get("num_attention_heads")
|
||||
if hs and nh:
|
||||
head_size = hs // nh
|
||||
|
||||
if head_size is None:
|
||||
sys.exit(0)
|
||||
|
||||
SUPPORTED = {32, 64, 96, 128, 160, 192, 224, 256}
|
||||
if head_size not in SUPPORTED:
|
||||
print(head_size)
|
||||
sys.exit(2)
|
||||
|
||||
sys.exit(0)
|
||||
Reference in New Issue
Block a user