Files
bi_150-vllm/vllm/transformers_utils/configs/extract_hidden_states.py

54 lines
1.7 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Config definitions for ExtractHiddenStatesModel, to be used with
the extract_hidden_states spec decoding method."""
import os
from transformers import PretrainedConfig
class ExtractHiddenStatesConfig(PretrainedConfig):
    """Configuration class for ExtractHiddenStatesModel.

    The config is built by flattening the fields of an optional base model
    config together with any extra keyword arguments, and then forcing
    ``architectures`` to ``["ExtractHiddenStatesModel"]``.
    """

    model_type = "extract_hidden_states"

    def __init__(
        self,
        model: PretrainedConfig | dict | None = None,
        method: str | None = "extract_hidden_states",
        **kwargs,
    ) -> None:
        """Build the config from a base model config plus overrides.

        Args:
            model: Base config. A dict is used as-is, a ``PretrainedConfig``
                is converted with ``to_dict()``, and any other value
                (including ``None``) contributes no fields.
            method: Must equal ``"extract_hidden_states"``; this is checked
                with an ``assert``.
            **kwargs: Additional fields. On a key clash these take
                precedence over the entries coming from ``model``.
        """
        assert method == "extract_hidden_states"

        if isinstance(model, dict):
            base_fields = model
        elif isinstance(model, PretrainedConfig):
            base_fields = model.to_dict()
        else:
            base_fields = {}

        # Base fields go in first so that explicit kwargs win on conflicts.
        init_fields = {**base_fields, **kwargs}

        # Whatever "architectures" the inputs carried is discarded; the
        # value is always set explicitly (and re-inserted as the last key).
        init_fields.pop("architectures", None)
        init_fields["architectures"] = ["ExtractHiddenStatesModel"]

        super().__init__(**init_fields)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str | os.PathLike,
        **kwargs,
    ) -> "ExtractHiddenStatesConfig":
        """Create a config from a pretrained model name or path.

        The config dict is obtained through ``cls.get_config_dict`` and then
        handed, together with the keyword arguments that call returns, to
        ``cls.from_dict``.
        """
        loaded_dict, leftover_kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )
        return cls.from_dict(loaded_dict, **leftover_kwargs)

    def to_json_string(self, use_diff: bool = True) -> str:
        """Serialize the config to a JSON string with ``use_diff=False``.

        ``use_diff`` is accepted only to keep the parent signature; its value
        is ignored because initializing ExtractHiddenStatesConfig with
        default arguments is not supported.
        """
        del use_diff  # intentionally unused, see docstring
        return super().to_json_string(use_diff=False)