54 lines
1.7 KiB
Python
54 lines
1.7 KiB
Python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Config definitions for ExtractHiddenStatesModel, to be used with
the extract_hidden_states spec decoding method."""
|
|
|
|
import os
|
|
|
|
from transformers import PretrainedConfig
|
|
|
|
|
|
class ExtractHiddenStatesConfig(PretrainedConfig):
    """Configuration for the ``extract_hidden_states`` model type.

    The config is built from the fields of a base model configuration plus
    any keyword overrides, and always advertises
    ``["ExtractHiddenStatesModel"]`` as its ``architectures``.
    """

    model_type = "extract_hidden_states"

    def __init__(
        self,
        model: PretrainedConfig | dict | None = None,
        method: str | None = "extract_hidden_states",
        **kwargs,
    ):
        """Build the config from a base model config and keyword overrides.

        Args:
            model: Base configuration, given either as a plain dict or as a
                ``PretrainedConfig`` (converted with ``to_dict()``). Any
                other value, including ``None``, contributes no fields.
            method: Must equal ``"extract_hidden_states"``; checked with an
                ``assert`` and not forwarded to the parent constructor.
            **kwargs: Extra fields. They take precedence over same-named
                fields coming from ``model``.
        """
        assert method == "extract_hidden_states"

        if isinstance(model, dict):
            base_fields = model
        else:
            base_fields = (
                model.to_dict() if isinstance(model, PretrainedConfig) else {}
            )

        # Work on a copy so the caller's dict is never mutated; explicit
        # keyword arguments win over fields from the base config.
        init_fields = dict(base_fields)
        init_fields.update(kwargs)

        # Whatever architectures the inputs declared is discarded; the entry
        # is re-added last so it always names this model.
        init_fields.pop("architectures", None)
        init_fields["architectures"] = ["ExtractHiddenStatesModel"]

        super().__init__(**init_fields)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str | os.PathLike,
        **kwargs,
    ) -> "ExtractHiddenStatesConfig":
        """Load a config by name or path.

        The raw config dict is fetched with ``get_config_dict`` and the
        instance is then created through ``from_dict``, receiving the
        keyword arguments that ``get_config_dict`` hands back.
        """
        loaded = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        config_dict, remaining_kwargs = loaded
        return cls.from_dict(config_dict, **remaining_kwargs)

    def to_json_string(self, use_diff: bool = True) -> str:
        """Serialize the config to JSON, always with ``use_diff=False``.

        ``use_diff`` is accepted for signature compatibility but ignored:
        initializing ExtractHiddenStatesConfig with default arguments is not
        supported, so the parent is always asked for the non-diff output.
        """
        del use_diff  # intentionally unused, see docstring
        return super().to_json_string(use_diff=False)