add src
This commit is contained in:
33
llm_utils.py
Normal file
33
llm_utils.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class ModelConfig:
    """Lightweight wrapper around a Hugging Face model configuration.

    Loads the config for a model path and exposes the pieces this project
    needs: the architecture identifier and the declared context length.
    """

    # Attribute names under which different architectures declare their
    # maximum sequence length in the Hugging Face config.
    _LENGTH_KEYS = (
        "max_position_embeddings",  # OPT
        "n_positions",              # GPT-2
        "max_seq_len",              # MPT
        "seq_length",               # ChatGLM2
        "max_sequence_length",      # Others
        "max_seq_length",
        "seq_len",
    )

    def __init__(self, model_path: str):
        """Load the Hugging Face config for the model at *model_path*."""
        self.hf_config = PretrainedConfig.from_pretrained(model_path)

    def model_type(self):
        """Return the architecture identifier reported by the HF config."""
        return self.hf_config.model_type

    def max_model_len(self):
        """Return the model's maximum context length, or ``None``.

        Different architectures name this limit differently (see
        ``_LENGTH_KEYS``); if several are present, the smallest wins. When
        the config declares none of them, ``None`` is returned.
        """
        limit = float("inf")
        for attr_name in self._LENGTH_KEYS:
            declared = getattr(self.hf_config, attr_name, None)
            if declared is not None:
                limit = min(limit, declared)
        # The untouched sentinel means no length key was present at all.
        return None if limit == float("inf") else limit
|
||||
Reference in New Issue
Block a user