105 lines
3.2 KiB
Python
105 lines
3.2 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
|
|
|
from vllm.logger import init_logger
|
|
|
|
logger = init_logger(__name__)

# Module-level flag written by set_is_gated() and read by get_is_gated().
# NOTE(review): presumably marks whether a gated FFN is in use — confirm.
IS_GATED: bool = False
|
|
|
|
class MluHijackObject:
    """Registry of attribute swaps ("hijacks") applied to classes or modules.

    ``apply_hijack`` replaces an attribute of an object with a new value and
    records the swap in ``hijack_objs``; ``undo_hijack`` uses those records
    to put the original attributes back.
    """

    # (obj, org_func, hijack_func) tuples in the order the hijacks were
    # applied. The list lives on the class, so it is shared by all callers.
    hijack_objs = []

    @staticmethod
    def _resolve_attr_name(org_func):
        """Return the attribute name that ``org_func`` refers to.

        ``org_func`` may be the attribute name itself (a ``str``), a
        ``property`` (its getter's name is used), or a function/method.
        For callables only the last ``__``-separated segment of the name is
        kept, so an earlier replacement named e.g. ``pkg__Cls__forward``
        resolves to ``forward``; dunder names such as ``__init__`` are kept
        intact.
        """
        if isinstance(org_func, str):
            return org_func
        if isinstance(org_func, property):
            func_name = org_func.fget.__name__
        else:
            func_name = org_func.__name__
        parts = func_name.split('__')
        name = parts[-1]
        if name == "":
            # The name ends with "__" (e.g. "__init__" -> ['', 'init', '']),
            # so rebuild the trailing/leading double underscores.
            assert parts[-2] != "", f"invalid {func_name} to apply hijack"
            name = parts[-2] + "__"
            if len(parts) >= 3 and parts[-3] == "":
                name = "__" + name
        return name

    @classmethod
    def _restore(cls, obj, org_func):
        """Revert a single recorded hijack of ``obj``."""
        if isinstance(org_func, str):
            # Only the name was recorded, so there is no saved original to
            # put back: the attribute is deleted instead.
            if hasattr(obj, org_func):
                delattr(obj, org_func)
        else:
            # Resolve the name exactly as apply_hijack did, so chained
            # hijacks and properties are restored onto the right attribute.
            setattr(obj, cls._resolve_attr_name(org_func), org_func)

    @classmethod
    def apply_hijack(cls, obj, org_func, hijack_func,
                     verify_orig_func_exists: bool = False):
        """Replace an attribute of ``obj`` with ``hijack_func`` and record it.

        Args:
            obj: Object (e.g. a class or module) whose attribute is replaced.
            org_func: The attribute to replace: its name as a ``str``, a
                ``property``, or the current function/method object.
            hijack_func: Value installed in place of the original.

        Optional Args:
            verify_orig_func_exists (bool): If True, raise ``AttributeError``
                when ``obj`` has no attribute of the resolved name before the
                swap, or when ``getattr`` does not return ``hijack_func``
                after it.

        Raises:
            AttributeError: On a failed verification (see above).
            AssertionError: If ``org_func``'s name cannot be parsed.
        """
        org_func_name = cls._resolve_attr_name(org_func)

        if verify_orig_func_exists and not hasattr(obj, org_func_name):
            raise AttributeError(f"function {org_func_name} is not part of {obj}")

        # Record only once the swap is about to happen, so a rejected call
        # leaves no bogus entry for undo_hijack() to replay.
        cls.hijack_objs.append((obj, org_func, hijack_func))
        setattr(obj, org_func_name, hijack_func)

        if (verify_orig_func_exists
                and getattr(obj, org_func_name) is not hijack_func):
            raise AttributeError(
                f"function {org_func_name} of {obj} failed to be swapped to {hijack_func}")

    @classmethod
    def undo_hijack(cls, obj_=None, hijack_func_=None):
        """Revert recorded hijacks.

        When both ``obj_`` and ``hijack_func_`` are given, only the records
        of ``obj_`` that installed ``hijack_func_`` are reverted. Otherwise
        every recorded hijack is reverted. Records are processed newest
        first, so an attribute hijacked several times ends up holding the
        value it had before its first hijack. ``hijack_objs`` is left
        unchanged.
        """
        records = reversed(cls.hijack_objs)
        if obj_ is not None and hijack_func_ is not None:
            for obj, org_func, hijack_func in records:
                if obj_ == obj and hijack_func == hijack_func_:
                    cls._restore(obj, org_func)
            return
        for obj, org_func, _ in records:
            cls._restore(obj, org_func)
|
|
|
|
|
|
# Plain dict mapping each field name below to a zero/empty default value.
# NOTE(review): presumably used as a template that callers copy and fill in
# — confirm against its users. Keys (including the spelling "cla_coeffient")
# are read at runtime, so they must not be renamed here.
# NOTE(review): q_lora_rank defaults to the float 0.0 while the other size
# fields default to int 0 — confirm this is intentional.
TypedDict = dict(
    hidden_size=0,
    vocab_size=0,
    ffn_inner_size=0,
    moe_inner_size=0,
    layer_num=0,
    moe_layer_num=0,
    head_num=0,
    head_size=0,
    head_num_kv=0,
    tp_num=0,
    shared_expert_intermediate_size=0,
    shared_experts=0,
    qk_nope_head_dim=0,
    qk_rope_head_dim=0,
    q_lora_rank=0.0,
    num_attention_heads=0,
    kv_lora_rank=0,
    v_head_dim=0,
    use_gated_ffn=False,
    experts_num=0,
    topk_num=0,
    use_causal_mask=False,
    cla_coeffient=0,
    kv_cache_dtype="",
    smooth_quant_type="",
    data_type="",
    model_type="",
    filter_data_type="",
)
|
|
|
|
|
|
def set_is_gated(flag):
    """Store ``flag`` in the module-level ``IS_GATED`` setting.

    The value is stored exactly as given (it is not converted to ``bool``).
    """
    global IS_GATED
    IS_GATED = flag
|
|
|
|
def get_is_gated():
    """Return the current value of the module-level ``IS_GATED`` setting."""
    current_flag = IS_GATED
    return current_flag
|