# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

from transformers import PretrainedConfig, WhisperConfig

from vllm.logger import init_logger

logger = init_logger(__name__)


def adapt_config_dict(
    config_dict: dict[str, Any],
    defaults: dict[str, Any],
) -> PretrainedConfig:
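    """Convert a Mistral-format ``params.json`` dict into an HF-style config.

    The raw dict is remapped key-by-key (general args, quantization, MoE,
    YaRN, vision, audio), the matching ``architectures`` entry is selected,
    ``defaults`` fills in any missing keys, and the result is wrapped in a
    ``transformers.PretrainedConfig``.

    Illustrative sketch (key names follow the remapping helpers below; the
    values are made up):

        config = adapt_config_dict(
            {"dim": 4096, "n_layers": 32, "n_heads": 32, "n_kv_heads": 8,
             "hidden_dim": 14336, "norm_eps": 1e-5},
            defaults={"torch_dtype": "bfloat16"},
        )
        # -> PretrainedConfig with architectures=["MistralForCausalLM"]
    """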
    config_dict = _remap_general_mistral_args(config_dict)

    if bool(config_dict.get("quantization")):
        config_dict = _remap_mistral_quantization_args(config_dict)

    is_moe = bool(config_dict.get("moe"))
    is_mistral_large_3 = (
        is_moe and (config_dict["moe"].get("num_shared_experts") or 0) > 0
    )

    if config_dict.get("model_type") == "mamba":
        config_dict["architectures"] = ["Mamba2ForCausalLM"]
    elif is_moe and is_mistral_large_3:
        config_dict = _remap_moe_args(config_dict)
        config_dict["model_type"] = "deepseek_v3"
        config_dict["architectures"] = ["MistralLarge3ForCausalLM"]
assert "llama_4_scaling" in config_dict, (
"MistralLarge3 expect llama4 scaling config."
)
llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"]
assert all(
[
key in config_dict["llama_4_scaling"]
for key in llama_4_scaling_config_keys
]
), (
"llama_4_scaling config should define the keys: "
f"{','.join(llama_4_scaling_config_keys)}"
)
    elif is_moe:
        config_dict["architectures"] = ["MixtralForCausalLM"]
    else:
        config_dict["architectures"] = ["MistralForCausalLM"]

    if bool(config_dict.get("yarn")):
        config_dict = _remap_mistral_yarn_args(config_dict)

    if bool(config_dict.get("llama_4_scaling")):
        llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"]
        assert all(
            key in config_dict["llama_4_scaling"]
            for key in llama_4_scaling_config_keys
        ), (
            "llama_4_scaling config should define the keys: "
            f"{','.join(llama_4_scaling_config_keys)}"
        )

    is_vision = (config_dict.get("multimodal") or {}).get(
        "vision_encoder_args"
    ) or config_dict.get("vision_encoder")
    is_audio = bool(
        ((config_dict.get("multimodal") or {}).get("whisper_model_args") or {}).get(
            "encoder_args"
        )
    )
    assert not (is_vision and is_audio), "Vision and audio are mutually exclusive"

    if is_vision:
        config_dict = _remap_mistral_vision_args(config_dict)
    if is_audio:
        config_dict = _remap_mistral_audio_args(config_dict)

    for k, v in defaults.items():
        config_dict.setdefault(k, v)

    config = PretrainedConfig.from_dict(config_dict)
    logger.debug("Initialized config %s", config)
    return config


def _remap_mistral_vision_args(config: dict) -> dict:
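    """Split the dict into a Pixtral-style multimodal config: the remaining
    keys become ``text_config`` and the ``multimodal`` / ``vision_encoder``
    section becomes ``vision_config``."""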
if config.get("multimodal"):
vision_config = config.pop("multimodal")
else:
vision_config = config.pop("vision_encoder")
quant_config = config.get("quantization_config")
config = {
"model_type": "pixtral",
"architectures": ["PixtralForConditionalGeneration"],
"text_config": PretrainedConfig.from_dict(config),
"vision_config": PretrainedConfig.from_dict(vision_config),
}
if quant_config:
config["quantization_config"] = quant_config
return config
def _remap_mistral_yarn_args(config: dict) -> dict:
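    """Translate the Mistral ``yarn`` section into an HF ``rope_parameters``
    dict with ``rope_type="yarn"`` (``beta``/``alpha`` map to
    ``beta_fast``/``beta_slow``)."""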
    yarn_config_map = {
        "factor": "factor",
        "original_max_position_embeddings": "original_max_position_embeddings",
        "beta": "beta_fast",
        "alpha": "beta_slow",
        "apply_scale": "apply_yarn_scaling",
    }
    yarn_config = config.get("yarn") or {}

    config["rope_parameters"] = {
        "rope_type": "yarn",
        "mscale_all_dim": 1,
    }
    if rope_theta := config.pop("rope_theta", None):
        config["rope_parameters"]["rope_theta"] = rope_theta
    for old_name, new_name in yarn_config_map.items():
        if old_name in yarn_config:
            config["rope_parameters"][new_name] = yarn_config.pop(old_name)
    assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}"
    return config


def _remap_general_mistral_args(config: dict) -> dict:
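    """Rename top-level Mistral keys (``dim``, ``n_layers``, ...) to their HF
    equivalents and fill in common defaults such as ``model_type`` and
    ``max_position_embeddings``."""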
    # Mistral key -> HF key
    config_mapping = {
        "dim": "hidden_size",
        "norm_eps": "rms_norm_eps",
        "n_kv_heads": "num_key_value_heads",
        "n_layers": "num_hidden_layers",
        "n_heads": "num_attention_heads",
        "hidden_dim": "intermediate_size",
    }
    # HF key -> (Mistral key, default value)
    top_level_mapping_with_default = {
        "model_type": ("model_type", "transformer"),
        "hidden_act": ("activation", "silu"),
        "tie_word_embeddings": ("tied_embeddings", False),
        "max_seq_len": (
            "max_seq_len",
            config.get("max_position_embeddings", 128_000),
        ),
        "max_position_embeddings": ("max_position_embeddings", 128_000),
    }

    for key, new_key in config_mapping.items():
        if key in config:
            config[new_key] = config.pop(key)
    for new_key, (key, default_value) in top_level_mapping_with_default.items():
        config[new_key] = config.pop(key, default_value)
    return config


def _remap_mistral_quantization_args(config: dict) -> dict:
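    """Map the Mistral ``quantization`` section to an HF ``quantization_config``.

    Only fp8 (``qformat_weight == "fp8_e4m3"``) is handled; ``qscheme_act``
    selects between dynamic (``NO_SCALES``) and static (``TENSOR`` or unset)
    activation scaling.
    """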
if config.get("quantization"):
quantization = config.pop("quantization", {})
if quantization.get("qformat_weight") == "fp8_e4m3":
qscheme_act = quantization.get("qscheme_act")
assert qscheme_act in ("NO_SCALES", "TENSOR", None), (
"Only NO_SCALES and TENSOR (default) are supported for qscheme_act"
)
is_dynamic = qscheme_act == "NO_SCALES"
config["quantization_config"] = {
"quant_method": "fp8",
"activation_scheme": "dynamic" if is_dynamic else "static",
}
else:
raise ValueError(f"Found unknown quantization='{quantization}' in config")
return config
def _remap_mistral_audio_args(config: dict) -> dict:
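    """Wrap the text config and the Whisper encoder args into a Voxtral-style
    multimodal config: remaining keys become ``text_config`` and the encoder
    args populate an ``audio_config`` built from ``WhisperConfig``."""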
whisper_args = config["multimodal"].pop("whisper_model_args")
encoder_args = whisper_args["encoder_args"]
downsample_args = whisper_args["downsample_args"]
quant_config = config.get("quantization_config")
config = {
"model_type": "whixtral",
"architectures": ["VoxtralForConditionalGeneration"],
"text_config": PretrainedConfig.from_dict(config),
"audio_config": WhisperConfig(
num_mel_bins=encoder_args["audio_encoding_args"]["num_mel_bins"],
window_size=encoder_args["audio_encoding_args"]["window_size"],
sampling_rate=encoder_args["audio_encoding_args"]["sampling_rate"],
hop_length=encoder_args["audio_encoding_args"]["hop_length"],
downsample_factor=downsample_args["downsample_factor"],
d_model=encoder_args["dim"],
encoder_layers=encoder_args["n_layers"],
encoder_ffn_dim=encoder_args["hidden_dim"],
encoder_attention_heads=encoder_args["n_heads"],
vocab_size=encoder_args["vocab_size"],
max_source_positions=encoder_args["max_source_positions"],
is_encoder_decoder=False, # Override WhisperConfig default
),
}
if quant_config:
config["quantization_config"] = quant_config
return config
def _remap_moe_args(config: dict) -> dict:
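    """Rename the Mistral ``moe`` section to DeepSeek-V3-style MoE keys and
    pin the routing settings (``topk_method=None``, ``norm_topk_prob=True``,
    ``scoring_func="softmax"``)."""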
    moe_config_map = {
        "route_every_n": "moe_layer_freq",
        "first_k_dense_replace": "first_k_dense_replace",
        "num_experts_per_tok": "num_experts_per_tok",
        "num_experts": "n_routed_experts",
        "expert_hidden_dim": "moe_intermediate_size",
        "routed_scale": "routed_scaling_factor",
        "num_shared_experts": "n_shared_experts",
        "num_expert_groups": "n_group",
        "num_expert_groups_per_tok": "topk_group",
    }
    moe_config = config.get("moe", {})
    for old_name, new_name in moe_config_map.items():
        if old_name in moe_config:
            value = moe_config.pop(old_name)
            config[new_name] = value

    config["topk_method"] = None
    config["norm_topk_prob"] = True
    config["scoring_func"] = "softmax"
    return config