from transformers import PretrainedConfig


class ModelConfig:
    """Thin wrapper around a Hugging Face ``PretrainedConfig``.

    Loads the model's configuration from a local path or hub id and
    exposes convenience accessors for the architecture family and the
    maximum sequence length the model supports.
    """

    def __init__(self, model_path: str):
        """Load the Hugging Face config for *model_path*.

        NOTE(review): ``PretrainedConfig.from_pretrained`` may hit the
        network for hub ids — callers presumably pass a local directory;
        confirm against call sites.
        """
        self.hf_config = PretrainedConfig.from_pretrained(model_path)

    def model_type(self):
        """Return the architecture family string (e.g. ``"opt"``, ``"gpt2"``)."""
        return self.hf_config.model_type

    def max_model_len(self):
        """Return the maximum sequence length declared by the config.

        Different architectures name this field differently, so a list of
        known attribute names is probed on ``hf_config``. When several are
        present, the smallest value wins (conservative choice). Returns
        ``None`` when the config declares none of the known keys.
        """
        possible_keys = (
            # OPT
            "max_position_embeddings",
            # GPT-2
            "n_positions",
            # MPT
            "max_seq_len",
            # ChatGLM2
            "seq_length",
            # Others
            "max_sequence_length",
            "max_seq_length",
            "seq_len",
        )
        found = (getattr(self.hf_config, key, None) for key in possible_keys)
        # min(..., default=None) replaces the original hand-rolled
        # float("inf") sentinel loop; behavior is identical.
        return min((v for v in found if v is not None), default=None)
|