# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from vllm.logger import init_logger

logger = init_logger(__name__)

# Module-level flag; read and written through set_is_gated() / get_is_gated().
IS_GATED = False


class MluHijackObject:
    """Class-level registry of monkey patches ("hijacks") on arbitrary objects.

    :meth:`apply_hijack` replaces an attribute of ``obj`` with a new value and
    records the patch; :meth:`undo_hijack` reverts recorded patches. All state
    lives in the class attribute :attr:`hijack_objs`, shared by every caller.
    """

    # (obj, org_func, hijack_func) for every hijack that was actually applied,
    # oldest first. Entries are removed once they have been undone.
    hijack_objs = []

    @staticmethod
    def _resolve_attr_name(org_func):
        """Return the attribute name on the target designated by ``org_func``.

        Shared by :meth:`apply_hijack` and :meth:`undo_hijack` so that a patch
        is always reverted on the same attribute it was written to.

        * ``str``: used verbatim.
        * ``property``: derived from its getter's ``__name__``.
        * anything else: derived from its ``__name__``.

        Derived names are split on ``"__"`` and only the last component is
        kept; a trailing ``"__"`` (and a leading one, when present) is
        re-attached, so ``"__init__"`` stays ``"__init__"`` while ``"a__b"``
        becomes ``"b"``.

        Raises:
            AssertionError: if no name can be derived (e.g. ``"__"``).
        """
        if isinstance(org_func, str):
            return org_func
        named = org_func.fget if isinstance(org_func, property) else org_func
        parts = named.__name__.split("__")
        attr_name = parts[-1]
        if attr_name == "":
            # The name ends with "__": keep the preceding component + suffix.
            assert parts[-2] != "", f"invalid {named.__name__} to apply hijack"
            attr_name = parts[-2] + "__"
            if len(parts) >= 3 and parts[-3] == "":
                # The name also starts with "__" (a dunder such as "__init__").
                attr_name = "__" + attr_name
        return attr_name

    @classmethod
    def _restore(cls, obj, org_func):
        """Revert a single recorded hijack of ``obj``.

        For a string ``org_func`` there is no original value to put back, so
        the attribute is deleted (when present); otherwise ``org_func`` itself
        is written back under the resolved attribute name.
        """
        attr_name = cls._resolve_attr_name(org_func)
        if isinstance(org_func, str):
            if hasattr(obj, attr_name):
                delattr(obj, attr_name)
        else:
            setattr(obj, attr_name, org_func)

    @classmethod
    def apply_hijack(cls, obj, org_func, hijack_func,
                     verify_orig_func_exists: bool = False):
        """Replace an attribute of ``obj`` with ``hijack_func`` and record it.

        Args:
            obj: Object (class, module, instance, ...) whose attribute is
                patched.
            org_func: The original function/property, or the attribute name
                as a ``str``. The patched attribute name is derived from it
                (see :meth:`_resolve_attr_name`).
            hijack_func: Value installed in place of the original.

        Optional Args:
            verify_orig_func_exists (bool): If True, require that the attribute
                already exists on ``obj`` before patching and that the swap
                actually took effect afterwards.

        Raises:
            AttributeError: if verification was requested and failed.
            AssertionError: if no attribute name can be derived from
                ``org_func``.
        """
        org_func_name = cls._resolve_attr_name(org_func)
        if verify_orig_func_exists and not hasattr(obj, org_func_name):
            raise AttributeError(f"function {org_func_name} is not part of {obj}")
        setattr(obj, org_func_name, hijack_func)
        # Record only after the attribute was really written, so a rejected or
        # failed patch leaves no stale entry for undo_hijack() to "revert".
        cls.hijack_objs.append((obj, org_func, hijack_func))
        if (verify_orig_func_exists
                and getattr(obj, org_func_name) is not hijack_func):
            raise AttributeError(
                f"function {org_func_name} of {obj} failed to be swapped to {hijack_func}")

    @classmethod
    def undo_hijack(cls, obj_ = None, hijack_func_ = None):
        """Revert recorded hijacks.

        Args:
            obj_: Target object of the hijack to revert.
            hijack_func_: Replacement value that hijack installed.

        When both arguments are given, only the most recent recorded hijack of
        ``obj_`` that installed ``hijack_func_`` is reverted; if none is
        recorded a warning is logged and nothing else happens. Otherwise
        (either argument omitted) every recorded hijack is reverted, newest
        first, so patches stacked on one attribute unwind back to the
        original value. Reverted entries are removed from :attr:`hijack_objs`.
        """
        if obj_ is not None and hijack_func_ is not None:
            for idx in reversed(range(len(cls.hijack_objs))):
                obj, org_func, hijack_func = cls.hijack_objs[idx]
                if obj_ == obj and hijack_func == hijack_func_:
                    cls._restore(obj, org_func)
                    del cls.hijack_objs[idx]
                    return
            # A miss must not fall through to reverting every hijack.
            logger.warning("No recorded hijack of %s with %s to undo.",
                           obj_, hijack_func_)
            return

        while cls.hijack_objs:
            obj, org_func, _ = cls.hijack_objs.pop()
            cls._restore(obj, org_func)


# NOTE(review): this is a plain dict, not ``typing.TypedDict``, and the name
# shadows the typing construct for star-importers. It maps configuration field
# names to zero/empty/False placeholder values — presumably a template that is
# copied and filled in elsewhere; confirm against callers.
# NOTE(review): ``q_lora_rank`` defaults to the float 0.0 while the other
# counts default to int 0 — confirm that is intended. Keys (including the
# spelling "cla_coeffient") are runtime data; renaming one breaks lookups.
TypedDict = {
    "hidden_size": 0,
    "vocab_size": 0,
    "ffn_inner_size": 0,
    "moe_inner_size": 0,
    "layer_num": 0,
    "moe_layer_num": 0,
    "head_num": 0,
    "head_size": 0,
    "head_num_kv": 0,
    "tp_num": 0,
    "shared_expert_intermediate_size": 0,
    "shared_experts": 0,
    "qk_nope_head_dim": 0,
    "qk_rope_head_dim": 0,
    "q_lora_rank": 0.0,
    "num_attention_heads": 0,
    "kv_lora_rank": 0,
    "v_head_dim": 0,
    "use_gated_ffn": False,
    "experts_num": 0,
    "topk_num": 0,
    "use_causal_mask": False,
    "cla_coeffient": 0,
    "kv_cache_dtype": "",
    "smooth_quant_type": "",
    "data_type": "",
    "model_type": "",
    "filter_data_type": "",
}


def set_is_gated(flag):
    """Set the module-level ``IS_GATED`` flag to ``flag``."""
    global IS_GATED
    IS_GATED = flag


def get_is_gated():
    """Return the current value of the module-level ``IS_GATED`` flag."""
    return IS_GATED