# Standard library.
import ctypes

# Third-party (vLLM).
from vllm.logger import init_logger

logger = init_logger(__name__)

# Module-level flag, mutated via set_is_gated() and read via get_is_gated().
# NOTE(review): presumably records whether the current model's FFN is gated
# (cf. ModelConfig.use_gated_ffn) -- confirm against callers.
IS_GATED = False
class MluHijackObject:
    """Registry of monkey-patches ("hijacks") applied to arbitrary objects.

    Each hijack replaces an attribute on a target object with a replacement
    function while remembering the original so it can be restored later.
    All registrations live in the class-level ``hijack_objs`` list.
    """

    # (obj, org_func, hijack_func) triples for every hijack still applied.
    hijack_objs = []

    @staticmethod
    def _attr_name(org_func):
        """Return the attribute name on the target object for *org_func*.

        A plain string is used verbatim.  Otherwise the ``__name__`` of the
        function (or of a property's getter) is split on ``'__'`` and the
        last segment is taken, so helper names such as ``prefix__forward``
        map to ``forward``; dunder names like ``__init__`` are rebuilt
        intact.
        """
        if isinstance(org_func, str):
            return org_func
        if isinstance(org_func, property):
            raw = org_func.fget.__name__
        else:
            raw = org_func.__name__
        parts = raw.split('__')
        name = parts[-1]
        if name == "":
            # raw ended in '__' (e.g. '__init__' -> ['', 'init', '']):
            # re-attach the trailing underscores, plus the leading ones when
            # the name started with '__' as well.
            assert parts[-2] != "", f"invalid {raw} to apply hijack"
            name = parts[-2] + "__"
            if len(parts) >= 3 and parts[-3] == "":
                name = "__" + name
        return name

    @classmethod
    def apply_hijack(cls, obj, org_func, hijack_func):
        """Install *hijack_func* on *obj* and record the registration.

        *org_func* is the original attribute (function or property), or the
        attribute name as a string when there is no original to preserve.
        """
        cls.hijack_objs.append((obj, org_func, hijack_func))
        setattr(obj, cls._attr_name(org_func), hijack_func)

    @classmethod
    def _restore(cls, obj, org_func):
        """Undo one hijack: remove string-registered attrs, restore others."""
        if isinstance(org_func, str):
            # Registered by name only -- nothing to put back, just drop it.
            if hasattr(obj, org_func):
                delattr(obj, org_func)
        else:
            # Restore under the same name apply_hijack patched.  Going
            # through _attr_name (instead of org_func.__name__) also fixes
            # properties, which have no __name__ and previously raised
            # AttributeError here.
            setattr(obj, cls._attr_name(org_func), org_func)

    @classmethod
    def undo_hijack(cls, obj_=None, hijack_func_=None):
        """Undo hijacks.

        When both *obj_* and *hijack_func_* are given, only the first
        matching registration is undone (and dropped from the registry);
        a filtered call with no match is a no-op.  Otherwise every recorded
        hijack is undone and the registry is cleared.
        """
        if obj_ and hijack_func_:
            for entry in cls.hijack_objs:
                obj, org_func, hijack_func = entry
                if obj_ == obj and hijack_func == hijack_func_:
                    cls._restore(obj, org_func)
                    # Keep the registry accurate so a later undo-all does
                    # not re-restore an already-undone hijack.
                    cls.hijack_objs.remove(entry)
                    return
            # No match: explicitly return instead of falling through to the
            # undo-everything path below.
            return
        for obj, org_func, _ in cls.hijack_objs:
            cls._restore(obj, org_func)
        cls.hijack_objs.clear()
class ModelConfig(ctypes.Structure):
    """C-compatible model-configuration record handed to native code.

    NOTE(review): most numeric fields are declared ``c_double`` even where
    an integer would seem natural (layer_num, head_num, ...) -- presumably
    the native side expects doubles; confirm before changing any type.
    """

    # Names of the leading run of plain-double fields, in the exact order
    # the native struct expects them.
    _DOUBLE_FIELDS = (
        'hidden_size',
        'vocab_size',
        'ffn_inner_size',
        'moe_inner_size',
        'layer_num',
        'moe_layer_num',
        'head_num',
        'head_size',
        'head_num_kv',
        'tp_num',
        'shared_expert_intermediate_size',
        'shared_experts',
        'qk_nope_head_dim',
        'qk_rope_head_dim',
        'q_lora_rank',
        'num_attention_heads',
        'kv_lora_rank',
        'v_head_dim',
    )

    # Field order defines the C struct layout; do not reorder.
    _fields_ = [(field, ctypes.c_double) for field in _DOUBLE_FIELDS] + [
        ('use_gated_ffn', ctypes.c_bool),
        ('experts_num', ctypes.c_int),
        ('topk_num', ctypes.c_int),
        ('use_causal_mask', ctypes.c_bool),
        ('cla_coeffient', ctypes.c_double),
        ('kv_cache_dtype', ctypes.c_char_p),
        ('smooth_quant_type', ctypes.c_char_p),
        ('data_type', ctypes.c_char_p),
        ('model_type', ctypes.c_char_p),
        ('filter_data_type', ctypes.c_char_p),
    ]
def set_is_gated(flag):
    """Store *flag* as the module-level ``IS_GATED`` value."""
    global IS_GATED
    IS_GATED = flag
def get_is_gated():
    # Read-only accessor for the module-level IS_GATED flag (set via
    # set_is_gated).
    return IS_GATED