Files
enginex-mlu370-vllm/vllm-v0.6.2/vllm_mlu/vllm_mlu/mlu_hijack_utils.py
2026-02-04 17:22:39 +08:00

90 lines
3.1 KiB
Python

from vllm.logger import init_logger
import ctypes
# Module-level logger (not used by the code visible in this file; other
# modules may import it).
logger = init_logger(__name__)

# Module-wide flag, written by set_is_gated() and read by get_is_gated().
IS_GATED = False
class MluHijackObject:
    """Registry that monkey-patches ("hijacks") attributes and can undo it.

    Each :meth:`apply_hijack` call records an ``(obj, org_func, hijack_func)``
    tuple in ``hijack_objs``. :meth:`undo_hijack` uses those records to put
    the original attribute back. When ``org_func`` was a plain string name,
    it deletes the attribute instead.
    """

    # Class-level registry shared by all callers, in application order.
    hijack_objs = []

    @staticmethod
    def _resolve_attr_name(org_func):
        """Return the attribute name that *org_func* is installed under.

        A string is used verbatim. For a function, or for a property (via its
        getter), the name is derived from ``__name__`` split on ``"__"``:

        - The last segment is used, so a function named e.g.
          ``pkg__Cls__forward`` maps to ``forward``.
        - A name ending in ``"__"`` keeps that trailing ``"__"``.
        - It also gains a leading ``"__"`` when its segment is itself preceded
          by ``"__"``, so ``__init__`` maps to ``__init__``.

        Raises:
            ValueError: if no attribute name can be derived.
        """
        if isinstance(org_func, str):
            return org_func
        func = org_func.fget if isinstance(org_func, property) else org_func
        full_name = func.__name__
        parts = full_name.split('__')
        attr_name = parts[-1]
        if attr_name == "":
            if len(parts) < 2 or parts[-2] == "":
                raise ValueError(f"invalid {full_name} to apply hijack")
            attr_name = parts[-2] + "__"
            if len(parts) >= 3 and parts[-3] == "":
                attr_name = "__" + attr_name
        return attr_name

    @classmethod
    def _restore(cls, obj, org_func):
        """Revert a single recorded hijack on *obj*."""
        if isinstance(org_func, str):
            # The hijack set this attribute by name; remove it again. It may
            # already be gone (or exist only via inheritance); that is fine.
            try:
                delattr(obj, org_func)
            except AttributeError:
                pass
        else:
            # Reinstall the original under the same name apply_hijack used,
            # not under org_func.__name__ (which differs for stacked hijacks
            # and does not exist on property objects before Python 3.13).
            setattr(obj, cls._resolve_attr_name(org_func), org_func)

    @classmethod
    def apply_hijack(cls, obj, org_func, hijack_func):
        """Replace an attribute of *obj* with *hijack_func* and record it.

        Args:
            obj: object (e.g. a class or module) whose attribute is replaced.
            org_func: the original function/property being replaced, or a
                string naming the attribute to set on *obj*.
            hijack_func: the replacement value.

        Raises:
            ValueError: if no attribute name can be derived from *org_func*.
        """
        # Resolve first so an invalid name does not leave a stale record.
        attr_name = cls._resolve_attr_name(org_func)
        cls.hijack_objs.append((obj, org_func, hijack_func))
        setattr(obj, attr_name, hijack_func)

    @classmethod
    def undo_hijack(cls, obj_=None, hijack_func_=None):
        """Revert recorded hijacks and drop them from ``hijack_objs``.

        Args:
            obj_: object whose hijack should be reverted.
            hijack_func_: the replacement that was installed on *obj_*.

        If both arguments are given, only the most recently recorded hijack
        of *obj_* with *hijack_func_* is reverted. If there is no such record,
        nothing changes.

        Otherwise (either argument omitted) every recorded hijack is reverted,
        newest first. This lets stacked hijacks of one attribute unwind back
        to the true original.
        """
        registry = cls.hijack_objs
        if obj_ is not None and hijack_func_ is not None:
            for idx in range(len(registry) - 1, -1, -1):
                obj, org_func, hijack_func = registry[idx]
                if obj_ == obj and hijack_func == hijack_func_:
                    del registry[idx]
                    cls._restore(obj, org_func)
                    return
            # No matching record: do NOT fall through to undoing everything.
            return
        while registry:
            obj, org_func, _ = registry.pop()
            cls._restore(obj, org_func)
class ModelConfig(ctypes.Structure):
    """ctypes structure describing a model configuration.

    NOTE(review): presumably passed to a native (C/C++) library via ctypes
    elsewhere. Confirm against the matching C struct definition.

    ctypes lays the fields out in exactly the order and with exactly the C
    types listed in ``_fields_``. Reordering, inserting or retyping an entry
    therefore changes the memory layout and silently breaks compatibility
    with the native side.
    """
    # Order and ctypes types below define the binary layout; keep them in
    # sync with the native definition.
    _fields_ = [
        # Numeric fields declared as C doubles, including fields whose names
        # suggest integer counts (e.g. layer_num, head_num).
        ('hidden_size', ctypes.c_double),
        ('vocab_size', ctypes.c_double),
        ('ffn_inner_size', ctypes.c_double),
        ('moe_inner_size', ctypes.c_double),
        ('layer_num', ctypes.c_double),
        ('moe_layer_num', ctypes.c_double),
        ('head_num', ctypes.c_double),
        ('head_size', ctypes.c_double),
        ('head_num_kv', ctypes.c_double),
        ('tp_num', ctypes.c_double),
        ('shared_expert_intermediate_size', ctypes.c_double),
        ('shared_experts', ctypes.c_double),
        ('qk_nope_head_dim', ctypes.c_double),
        ('qk_rope_head_dim', ctypes.c_double),
        ('q_lora_rank', ctypes.c_double),
        ('num_attention_heads', ctypes.c_double),
        ('kv_lora_rank', ctypes.c_double),
        ('v_head_dim', ctypes.c_double),
        ('use_gated_ffn', ctypes.c_bool),
        # The only two C int fields in this struct.
        ('experts_num', ctypes.c_int),
        ('topk_num', ctypes.c_int),
        ('use_causal_mask', ctypes.c_bool),
        # Spelling ("coeffient") is kept as-is: Python code may access the
        # field by this name. Field names do not affect the C layout.
        ('cla_coeffient', ctypes.c_double),
        # c_char_p fields take bytes (not str) when assigned from Python.
        ('kv_cache_dtype', ctypes.c_char_p),
        ('smooth_quant_type', ctypes.c_char_p),
        ('data_type', ctypes.c_char_p),
        ('model_type', ctypes.c_char_p),
        ('filter_data_type', ctypes.c_char_p),
    ]
def set_is_gated(flag):
    """Store *flag* (unchanged, no coercion) in the module-level ``IS_GATED``."""
    global IS_GATED
    IS_GATED = flag


def get_is_gated():
    """Return whatever value ``IS_GATED`` currently holds."""
    current = IS_GATED
    return current