添加 K100-vLLM-Patched-v2.0/detect_head_size.py
This commit is contained in:
27
K100-vLLM-Patched-v2.0/detect_head_size.py
Normal file
27
K100-vLLM-Patched-v2.0/detect_head_size.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import json, os, sys
|
||||||
|
|
||||||
|
MODEL_DIR = os.environ.get("MODEL_DIR", "/model")
|
||||||
|
cfg_path = os.path.join(MODEL_DIR, "config.json")
|
||||||
|
|
||||||
|
if not os.path.exists(cfg_path):
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
with open(cfg_path) as f:
|
||||||
|
cfg = json.load(f)
|
||||||
|
|
||||||
|
head_size = cfg.get("head_dim")
|
||||||
|
if head_size is None:
|
||||||
|
hs = cfg.get("hidden_size")
|
||||||
|
nh = cfg.get("num_attention_heads")
|
||||||
|
if hs and nh:
|
||||||
|
head_size = hs // nh
|
||||||
|
|
||||||
|
if head_size is None:
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
SUPPORTED = {32, 64, 96, 128, 160, 192, 224, 256}
|
||||||
|
if head_size not in SUPPORTED:
|
||||||
|
print(head_size)
|
||||||
|
sys.exit(2)
|
||||||
|
|
||||||
|
sys.exit(0)
|
||||||
Reference in New Issue
Block a user