[Model] Support DeepSeek-V4
This commit is contained in:
104
vllm_mlu/mlu_hijack_utils.py
Normal file
104
vllm_mlu/mlu_hijack_utils.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
# Module-level logger created through vLLM's init_logger helper, keyed by this module's name.
logger = init_logger(__name__)
|
||||
|
||||
# Process-wide flag; written by set_is_gated() and read by get_is_gated() below.
IS_GATED=False
|
||||
|
||||
class MluHijackObject:
    """Registry for monkey-patching attributes and undoing those patches.

    ``apply_hijack`` replaces (or adds) an attribute on a target object and
    records the patch; ``undo_hijack`` puts the recorded originals back.
    All state lives on the class, so the registry is shared process-wide.
    """

    # Log of installed patches as (obj, org_func, hijack_func) tuples, in the
    # order they were applied. Entries are kept after an undo.
    hijack_objs = []

    @staticmethod
    def _resolve_attr_name(org_func):
        """Return the attribute name under which ``org_func`` is patched.

        * A ``str`` is used verbatim.
        * For a ``property`` the getter's ``__name__`` is used.
        * Otherwise ``__name__`` is split on ``"__"`` and the last component
          is taken, so ``pkg__Cls__forward`` maps to ``forward``. A trailing
          ``"__"`` marks a dunder-style name: ``__init__`` and
          ``pkg__Cls____init__`` both map to ``__init__``.

        Raises:
            AssertionError: if no usable name can be derived (for example the
                name is empty or consists only of underscores).
        """
        if isinstance(org_func, str):
            return org_func

        # Properties carry their name on the getter, not on the descriptor.
        named = org_func.fget if isinstance(org_func, property) else org_func
        full_name = named.__name__
        parts = full_name.split('__')

        attr_name = parts[-1]
        if attr_name != "":
            return attr_name

        # The name ended with "__": rebuild "<stem>__" and, when the stem was
        # also preceded by "__", the full "__<stem>__" dunder form.
        if len(parts) < 2 or parts[-2] == "":
            # Raised explicitly (not via ``assert``) so that the check is not
            # stripped under ``python -O``.
            raise AssertionError(f"invalid {full_name} to apply hijack")
        attr_name = parts[-2] + "__"
        if len(parts) >= 3 and parts[-3] == "":
            attr_name = "__" + attr_name
        return attr_name

    @classmethod
    def _restore(cls, obj, org_func):
        """Undo a single recorded patch on ``obj``."""
        if isinstance(org_func, str):
            # The patch was registered by name only, so there is no original
            # object to put back; the attribute is removed instead.
            if hasattr(obj, org_func):
                delattr(obj, org_func)
        else:
            # Restore under the same name apply_hijack() installed the patch
            # under, which may differ from ``org_func.__name__``.
            setattr(obj, cls._resolve_attr_name(org_func), org_func)

    @classmethod
    def apply_hijack(cls, obj, org_func, hijack_func,
                     verify_orig_func_exists: bool = False):
        """Install ``hijack_func`` on ``obj`` in place of ``org_func``.

        Args:
            obj: Object (typically a class or module) whose attribute is
                replaced.
            org_func: The original callable or ``property``, or a ``str``
                naming the attribute to set.
            hijack_func: Replacement object to install.

        Optional Args:
            verify_orig_func_exists (bool): If True, verifies that the
                attribute already exists on ``obj`` before patching and that
                reading it back afterwards yields ``hijack_func``.

        Raises:
            AttributeError: if verification is requested and fails.
            AssertionError: if no attribute name can be derived from
                ``org_func``.
        """
        org_func_name = cls._resolve_attr_name(org_func)

        if verify_orig_func_exists and not hasattr(obj, org_func_name):
            raise AttributeError(f"function {org_func_name} is not part of {obj}")

        setattr(obj, org_func_name, hijack_func)
        # Record only once the attribute has really been replaced, so that a
        # rejected request never leaves an entry for undo_hijack() to act on.
        cls.hijack_objs.append((obj, org_func, hijack_func))

        if verify_orig_func_exists and getattr(obj, org_func_name) is not hijack_func:
            raise AttributeError(
                f"function {org_func_name} of {obj} failed to be swapped to {hijack_func}")

    @classmethod
    def undo_hijack(cls, obj_ = None, hijack_func_ = None):
        """Revert recorded patches.

        With both ``obj_`` and ``hijack_func_`` given, only the first recorded
        patch matching that pair is reverted. Otherwise every recorded patch
        is reverted, in the order it was applied.

        Args:
            obj_: Target object of the patch to revert (optional).
            hijack_func_: Replacement that was installed on ``obj_``
                (optional).
        """
        if obj_ and hijack_func_:
            for obj, org_func, hijack_func in cls.hijack_objs:
                if obj_ == obj and hijack_func == hijack_func_:
                    cls._restore(obj, org_func)
                    return
            # NOTE(review): a targeted request that matches nothing falls
            # through to the full undo below -- confirm this is intended.
        for obj, org_func, _ in cls.hijack_objs:
            cls._restore(obj, org_func)
|
||||
|
||||
|
||||
# Template mapping of model-description fields to placeholder defaults
# (0 / 0.0 / False / "" according to the kind of value each field holds).
# Despite its name this is a plain dict, not a ``typing.TypedDict``.
# NOTE(review): the "cla_coeffient" spelling is part of the key and is kept
# verbatim, since code elsewhere may look the key up by this exact string.
TypedDict = {
    # Numeric fields.
    "hidden_size": 0,
    "vocab_size": 0,
    "ffn_inner_size": 0,
    "moe_inner_size": 0,
    "layer_num": 0,
    "moe_layer_num": 0,
    "head_num": 0,
    "head_size": 0,
    "head_num_kv": 0,
    "tp_num": 0,
    "shared_expert_intermediate_size": 0,
    "shared_experts": 0,
    "qk_nope_head_dim": 0,
    "qk_rope_head_dim": 0,
    "q_lora_rank": 0.0,
    "num_attention_heads": 0,
    "kv_lora_rank": 0,
    "v_head_dim": 0,
    "use_gated_ffn": False,
    "experts_num": 0,
    "topk_num": 0,
    "use_causal_mask": False,
    "cla_coeffient": 0,
    # String fields.
    "kv_cache_dtype": "",
    "smooth_quant_type": "",
    "data_type": "",
    "model_type": "",
    "filter_data_type": "",
}
|
||||
|
||||
|
||||
def set_is_gated(flag):
    """Store ``flag`` in the module-level ``IS_GATED`` variable.

    Args:
        flag: New value for ``IS_GATED``; stored as given, without conversion.
    """
    # ``global`` is required so the assignment rebinds the module variable
    # instead of creating a function-local name.
    global IS_GATED
    IS_GATED = flag
|
||||
|
||||
def get_is_gated():
    """Return the current value of the module-level ``IS_GATED`` variable."""
    return IS_GATED
|
||||
Reference in New Issue
Block a user