# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os

from transformers import PretrainedConfig


class MedusaConfig(PretrainedConfig):
    model_type = "medusa"

    def __init__(
        self,
        hidden_size: int = 4096,
        vocab_size: int = 32001,
        num_heads: int = 5,
        num_hidden_layers: int = 1,
        max_paths: int = 64,
        topk: int = 10,
        truncated_vocab_size: int | None = None,
        **kwargs,
    ):
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_heads = num_heads
        self.num_hidden_layers = num_hidden_layers
        self.max_paths = max_paths
        self.topk = topk
        self.max_seq_len = int(2**20)
        # Fall back to the full vocabulary size when no truncated size is given.
        self.truncated_vocab_size = (
            vocab_size if truncated_vocab_size is None else truncated_vocab_size
        )
        if "architectures" not in kwargs:
            kwargs["architectures"] = ["MedusaModel"]
        super().__init__(**kwargs)
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str | os.PathLike,
        **kwargs,
    ) -> "MedusaConfig":
        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )
        # Normalize checkpoint key names: a key containing "num" and "heads"
        # (e.g. "medusa_num_heads") becomes "num_heads", and one containing
        # "num" and "layers" becomes "num_hidden_layers".
        for k in list(config_dict.keys()):
            if "num" in k:
                if "heads" in k:
                    config_dict["num_heads"] = config_dict.pop(k)
                elif "layers" in k:
                    config_dict["num_hidden_layers"] = config_dict.pop(k)
        return cls.from_dict(config_dict, **kwargs)
    @property
    def num_attention_heads(self):
        # The Medusa draft heads are MLP-based and contain no attention layers.
        return 0

    @property
    def num_lookahead_tokens(self):
        # Each Medusa head speculates one token, so the number of lookahead
        # tokens equals the number of heads.
        return self.num_heads

    @num_lookahead_tokens.setter
    def num_lookahead_tokens(self, num_lookahead_tokens: int):
        self.num_heads = num_lookahead_tokens
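
# Usage sketch (illustrative): constructing the config by hand and exercising
# the num_lookahead_tokens alias; the checkpoint path below is hypothetical.
if __name__ == "__main__":
    config = MedusaConfig(num_heads=4, truncated_vocab_size=32000)
    assert config.num_lookahead_tokens == 4
    config.num_lookahead_tokens = 5  # the setter forwards to num_heads
    assert config.num_heads == 5
    # Loading from an on-disk Medusa checkpoint would instead be:
    # config = MedusaConfig.from_pretrained("/path/to/medusa-checkpoint")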