Files
enginex-mlu590-vllm/vllm_mlu/mlu_hijack_utils.py

105 lines
3.2 KiB
Python
Raw Permalink Normal View History

2026-04-24 09:50:34 +08:00
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.logger import init_logger
# Module logger obtained from vLLM's logging helper.
logger = init_logger(__name__)

# Module-level flag; written by set_is_gated() and read by get_is_gated().
IS_GATED = False
class MluHijackObject:
    """Registry for replacing ("hijacking") attributes of objects at runtime.

    ``apply_hijack`` swaps an attribute on a target object (class, module,
    instance, ...) and records the swap in ``hijack_objs`` so that
    ``undo_hijack`` can put the original back later.
    """

    # Recorded swaps as (obj, org_func, hijack_func) tuples, oldest first.
    # This is a class attribute, so the registry is shared by every caller.
    hijack_objs = []

    @staticmethod
    def _resolve_func_name(org_func):
        """Return the attribute name that ``org_func`` is stored under.

        Args:
            org_func: The attribute name itself (``str``), a ``property``, or
                a function object.

        Returns:
            For a ``str``, the string unchanged. For a property, its getter's
            name is used; for a function, its ``__name__``. Of that name, only
            the last ``"__"``-separated segment is kept, so ``a__b__forward``
            maps to ``forward``. Names ending in ``"__"`` are rebuilt as
            dunders, e.g. ``__init__`` and ``Foo____init__`` both map to
            ``__init__``.

        Raises:
            ValueError: If no attribute name can be derived.
        """
        if isinstance(org_func, str):
            return org_func
        if isinstance(org_func, property):
            raw_name = org_func.fget.__name__
        else:
            raw_name = org_func.__name__
        split_name = raw_name.split('__')
        func_name = split_name[-1]
        if func_name == "":
            # The name ends in "__": treat it as a dunder such as "__init__",
            # possibly behind a "__"-separated prefix.
            if len(split_name) < 2 or split_name[-2] == "":
                raise ValueError(f"invalid {raw_name} to apply hijack")
            func_name = split_name[-2] + "__"
            if len(split_name) >= 3 and split_name[-3] == "":
                func_name = "__" + func_name
        return func_name

    @classmethod
    def _restore_original(cls, obj, org_func):
        """Revert a single recorded swap on ``obj``."""
        if isinstance(org_func, str):
            # The string form keeps no original object, so the attribute that
            # apply_hijack set is simply removed again.
            if hasattr(obj, org_func):
                delattr(obj, org_func)
        else:
            # Resolve the name exactly as apply_hijack did, so the original is
            # put back under the same attribute it was taken from.
            setattr(obj, cls._resolve_func_name(org_func), org_func)

    @classmethod
    def apply_hijack(cls, obj, org_func, hijack_func,
                     verify_orig_func_exists: bool = False):
        """Replace an attribute of ``obj`` with ``hijack_func`` and record it.

        Args:
            obj: Object whose attribute is replaced.
            org_func: The attribute being replaced. Either its name as a
                ``str`` (no original is kept, so undo deletes the attribute),
                a ``property``, or the original function object.
            hijack_func: The replacement value.

        Optional Args:
            verify_orig_func_exists (bool): If True, verifies that ``obj``
                already has the attribute and that the hijack succeeds.

        Raises:
            AttributeError: If verification is requested and fails.
            ValueError: If no attribute name can be derived from ``org_func``.
        """
        org_func_name = cls._resolve_func_name(org_func)
        if verify_orig_func_exists and not hasattr(obj, org_func_name):
            raise AttributeError(f"function {org_func_name} is not part of {obj}")
        setattr(obj, org_func_name, hijack_func)
        # Record only once the swap has actually happened, but before the
        # post-check, so undo_hijack can still revert a swap that fails it.
        cls.hijack_objs.append((obj, org_func, hijack_func))
        if (verify_orig_func_exists
                and getattr(obj, org_func_name) is not hijack_func):
            raise AttributeError(
                f"function {org_func_name} of {obj} failed to be swapped to {hijack_func}")

    @classmethod
    def undo_hijack(cls, obj_=None, hijack_func_=None):
        """Restore hijacked attributes.

        When both ``obj_`` and ``hijack_func_`` are given, only the first
        recorded swap matching that pair is undone; if none matches, nothing
        happens. Otherwise every recorded swap is undone newest-first, so
        that stacked hijacks of the same attribute unwind back to the
        original. Undone swaps are removed from ``hijack_objs``.

        Args:
            obj_: Target object of the hijack to undo.
            hijack_func_: Replacement that was installed on ``obj_``.
        """
        if obj_ is not None and hijack_func_ is not None:
            for index, (obj, org_func, hijack_func) in enumerate(cls.hijack_objs):
                if obj_ == obj and hijack_func == hijack_func_:
                    cls._restore_original(obj, org_func)
                    del cls.hijack_objs[index]
                    break
            return
        # Newest first: a later hijack's org_func is the previous replacement,
        # so reverting oldest-first would leave an intermediate one in place.
        for obj, org_func, _ in reversed(cls.hijack_objs):
            cls._restore_original(obj, org_func)
        cls.hijack_objs.clear()
# Placeholder values for a set of model-configuration fields, keyed by field
# name (each value is a zero / False / empty-string default).
# NOTE(review): despite its name this is a plain ``dict`` instance, not a
# ``typing.TypedDict`` class.
# NOTE(review): "q_lora_rank" defaults to the float 0.0 while the other numeric
# fields default to int 0 — confirm this is intended.
# NOTE(review): "cla_coeffient" looks like a misspelling of "coefficient", but
# it is a runtime key; do not rename it without updating every reader.
TypedDict = dict(
    hidden_size=0,
    vocab_size=0,
    ffn_inner_size=0,
    moe_inner_size=0,
    layer_num=0,
    moe_layer_num=0,
    head_num=0,
    head_size=0,
    head_num_kv=0,
    tp_num=0,
    shared_expert_intermediate_size=0,
    shared_experts=0,
    qk_nope_head_dim=0,
    qk_rope_head_dim=0,
    q_lora_rank=0.0,
    num_attention_heads=0,
    kv_lora_rank=0,
    v_head_dim=0,
    use_gated_ffn=False,
    experts_num=0,
    topk_num=0,
    use_causal_mask=False,
    cla_coeffient=0,
    kv_cache_dtype="",
    smooth_quant_type="",
    data_type="",
    model_type="",
    filter_data_type="",
)
def set_is_gated(flag):
    """Store ``flag`` in the module-level ``IS_GATED`` variable.

    The value is stored as given; it is not converted to ``bool``.
    """
    global IS_GATED
    IS_GATED = flag


def get_is_gated():
    """Return the current value of the module-level ``IS_GATED`` variable."""
    return IS_GATED