# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any

from transformers import PretrainedConfig, WhisperConfig

from vllm.logger import init_logger

logger = init_logger(__name__)
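

# NOTE: the examples sprinkled through this module are illustrative sketches.
# The field values are hypothetical (not taken from a real checkpoint); only
# the key names follow the remapping logic implemented here.
#
# For instance, a minimal dense Mistral-format config dict such as
#   {"dim": 4096, "n_layers": 32, "n_heads": 32, "n_kv_heads": 8,
#    "hidden_dim": 14336, "norm_eps": 1e-5, "vocab_size": 32768}
# is turned by adapt_config_dict() into an HF-style PretrainedConfig with
# hidden_size / num_hidden_layers / num_attention_heads / ... attributes and
# architectures=["MistralForCausalLM"].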
def adapt_config_dict(
    config_dict: dict[str, Any],
    defaults: dict[str, Any],
) -> PretrainedConfig:
    config_dict = _remap_general_mistral_args(config_dict)

    if bool(config_dict.get("quantization")):
        config_dict = _remap_mistral_quantization_args(config_dict)

    is_moe = bool(config_dict.get("moe"))
    is_mistral_large_3 = (
        is_moe and (config_dict["moe"].get("num_shared_experts") or 0) > 0
    )
    if config_dict.get("model_type") == "mamba":
        config_dict["architectures"] = ["Mamba2ForCausalLM"]
    elif is_moe and is_mistral_large_3:
        config_dict = _remap_moe_args(config_dict)
        config_dict["model_type"] = "deepseek_v3"
        config_dict["architectures"] = ["MistralLarge3ForCausalLM"]

        assert "llama_4_scaling" in config_dict, (
            "MistralLarge3 expects a llama4 scaling config."
        )
        llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"]
        assert all(
            [
                key in config_dict["llama_4_scaling"]
                for key in llama_4_scaling_config_keys
            ]
        ), (
            "llama_4_scaling config should define the keys: "
            f"{','.join(llama_4_scaling_config_keys)}"
        )
    elif is_moe:
        config_dict["architectures"] = ["MixtralForCausalLM"]
    else:
        config_dict["architectures"] = ["MistralForCausalLM"]

    if bool(config_dict.get("yarn")):
        config_dict = _remap_mistral_yarn_args(config_dict)

    if bool(config_dict.get("llama_4_scaling")):
        llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"]
        assert all(
            [
                key in config_dict["llama_4_scaling"]
                for key in llama_4_scaling_config_keys
            ]
        ), (
            "llama_4_scaling config should define the keys: "
            f"{','.join(llama_4_scaling_config_keys)}"
        )

    is_vision = (config_dict.get("multimodal") or {}).get(
        "vision_encoder_args"
    ) or config_dict.get("vision_encoder")
    is_audio = bool(
        ((config_dict.get("multimodal") or {}).get("whisper_model_args") or {}).get(
            "encoder_args"
        )
    )

    assert not (is_vision and is_audio), "Vision and audio are mutually exclusive"

    if is_vision:
        config_dict = _remap_mistral_vision_args(config_dict)
    if is_audio:
        config_dict = _remap_mistral_audio_args(config_dict)

    for k, v in defaults.items():
        config_dict.setdefault(k, v)

    config = PretrainedConfig.from_dict(config_dict)

    logger.debug("Initialized config %s", config)

    return config
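

# Illustrative sketch (hypothetical values): a text config that also carries a
# top-level "vision_encoder" dict (or a "multimodal" dict) is split by the
# helper below into a Pixtral-style wrapper:
#   {"model_type": "pixtral",
#    "architectures": ["PixtralForConditionalGeneration"],
#    "text_config": <remaining keys>, "vision_config": <encoder keys>},
# with any existing "quantization_config" carried over unchanged.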
def _remap_mistral_vision_args(config: dict) -> dict:
    if config.get("multimodal"):
        vision_config = config.pop("multimodal")
    else:
        vision_config = config.pop("vision_encoder")

    quant_config = config.get("quantization_config")
    config = {
        "model_type": "pixtral",
        "architectures": ["PixtralForConditionalGeneration"],
        "text_config": PretrainedConfig.from_dict(config),
        "vision_config": PretrainedConfig.from_dict(vision_config),
    }
    if quant_config:
        config["quantization_config"] = quant_config
    return config
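

# Illustrative sketch (hypothetical values): a Mistral-format YaRN block
#   {"yarn": {"factor": 16.0, "original_max_position_embeddings": 32768,
#             "beta": 32, "alpha": 1}, "rope_theta": 1e6}
# is rewritten by the helper below into HF-style rope parameters:
#   {"rope_parameters": {"rope_type": "yarn", "mscale_all_dim": 1,
#                        "rope_theta": 1e6, "factor": 16.0,
#                        "original_max_position_embeddings": 32768,
#                        "beta_fast": 32, "beta_slow": 1}}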
def _remap_mistral_yarn_args(config: dict) -> dict:
    yarn_config_map = {
        "factor": "factor",
        "original_max_position_embeddings": "original_max_position_embeddings",
        "beta": "beta_fast",
        "alpha": "beta_slow",
        "apply_scale": "apply_yarn_scaling",
    }
    yarn_config = config.get("yarn") or {}
    config["rope_parameters"] = {
        "rope_type": "yarn",
        "mscale_all_dim": 1,
    }

    if rope_theta := config.pop("rope_theta", None):
        config["rope_parameters"]["rope_theta"] = rope_theta

    for old_name, new_name in yarn_config_map.items():
        if old_name in yarn_config:
            config["rope_parameters"][new_name] = yarn_config.pop(old_name)

    assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}"

    return config
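

# Illustrative sketch (hypothetical values) of the flat key renames performed
# by the helper below:
#   {"dim": 4096, "n_layers": 32, "n_heads": 32, "n_kv_heads": 8,
#    "hidden_dim": 14336, "norm_eps": 1e-5, "activation": "silu"}
# -> {"hidden_size": 4096, "num_hidden_layers": 32, "num_attention_heads": 32,
#     "num_key_value_heads": 8, "intermediate_size": 14336,
#     "rms_norm_eps": 1e-5, "hidden_act": "silu", "model_type": "transformer",
#     "tie_word_embeddings": False, "max_seq_len": 128000,
#     "max_position_embeddings": 128000}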
def _remap_general_mistral_args(config: dict) -> dict:
    # Mistral key -> HF key
    config_mapping = {
        "dim": "hidden_size",
        "norm_eps": "rms_norm_eps",
        "n_kv_heads": "num_key_value_heads",
        "n_layers": "num_hidden_layers",
        "n_heads": "num_attention_heads",
        "hidden_dim": "intermediate_size",
    }
    # HF key -> (Mistral key, default value)
    top_level_mapping_with_default = {
        "model_type": ("model_type", "transformer"),
        "hidden_act": ("activation", "silu"),
        "tie_word_embeddings": ("tied_embeddings", False),
        "max_seq_len": ("max_seq_len", config.get("max_position_embeddings", 128_000)),
        "max_position_embeddings": ("max_position_embeddings", 128_000),
    }

    for key, new_key in config_mapping.items():
        if key in config:
            config[new_key] = config.pop(key)

    for new_key, (key, default_value) in top_level_mapping_with_default.items():
        config[new_key] = config.pop(key, default_value)

    return config
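

# Illustrative sketch (hypothetical values): an FP8 checkpoint section such as
#   {"quantization": {"qformat_weight": "fp8_e4m3", "qscheme_act": "TENSOR"}}
# is rewritten by the helper below into
#   {"quantization_config": {"quant_method": "fp8",
#                            "activation_scheme": "static"}}
# ("NO_SCALES" selects the dynamic activation scheme instead).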
def _remap_mistral_quantization_args(config: dict) -> dict:
    if config.get("quantization"):
        quantization = config.pop("quantization", {})
        if quantization.get("qformat_weight") == "fp8_e4m3":
            qscheme_act = quantization.get("qscheme_act")
            assert qscheme_act in ("NO_SCALES", "TENSOR", None), (
                "Only NO_SCALES and TENSOR (default) are supported for qscheme_act"
            )
            is_dynamic = qscheme_act == "NO_SCALES"
            config["quantization_config"] = {
                "quant_method": "fp8",
                "activation_scheme": "dynamic" if is_dynamic else "static",
            }
        else:
            raise ValueError(f"Found unknown quantization='{quantization}' in config")

    return config
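

# Illustrative sketch (hypothetical values): a config with
#   {"multimodal": {"whisper_model_args": {
#       "encoder_args": {"dim": 1280, "n_layers": 32, "n_heads": 20,
#                        "hidden_dim": 5120, "vocab_size": 51866,
#                        "max_source_positions": 1500,
#                        "audio_encoding_args": {"num_mel_bins": 128, ...}},
#       "downsample_args": {"downsample_factor": 4}}}}
# is wrapped by the helper below into a Voxtral-style config holding the
# remaining keys as text_config and a transformers WhisperConfig as
# audio_config.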
def _remap_mistral_audio_args(config: dict) -> dict:
    whisper_args = config["multimodal"].pop("whisper_model_args")
    encoder_args = whisper_args["encoder_args"]
    downsample_args = whisper_args["downsample_args"]

    quant_config = config.get("quantization_config")
    config = {
        "model_type": "whixtral",
        "architectures": ["VoxtralForConditionalGeneration"],
        "text_config": PretrainedConfig.from_dict(config),
        "audio_config": WhisperConfig(
            num_mel_bins=encoder_args["audio_encoding_args"]["num_mel_bins"],
            window_size=encoder_args["audio_encoding_args"]["window_size"],
            sampling_rate=encoder_args["audio_encoding_args"]["sampling_rate"],
            hop_length=encoder_args["audio_encoding_args"]["hop_length"],
            downsample_factor=downsample_args["downsample_factor"],
            d_model=encoder_args["dim"],
            encoder_layers=encoder_args["n_layers"],
            encoder_ffn_dim=encoder_args["hidden_dim"],
            encoder_attention_heads=encoder_args["n_heads"],
            vocab_size=encoder_args["vocab_size"],
            max_source_positions=encoder_args["max_source_positions"],
            is_encoder_decoder=False,  # Override WhisperConfig default
        ),
    }
    if quant_config:
        config["quantization_config"] = quant_config
    return config
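

# Illustrative sketch (hypothetical values): a Mistral-format MoE block
#   {"moe": {"num_experts": 128, "num_experts_per_tok": 4,
#            "expert_hidden_dim": 2048, "num_shared_experts": 1,
#            "routed_scale": 2.5, "first_k_dense_replace": 1}}
# has its entries lifted by the helper below to top-level DeepSeek-V3-style
# keys (n_routed_experts, num_experts_per_tok, moe_intermediate_size,
# n_shared_experts, routed_scaling_factor, first_k_dense_replace), and
# topk_method / norm_topk_prob / scoring_func are pinned to fixed values.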
def _remap_moe_args(config: dict) -> dict:
    moe_config_map = {
        "route_every_n": "moe_layer_freq",
        "first_k_dense_replace": "first_k_dense_replace",
        "num_experts_per_tok": "num_experts_per_tok",
        "num_experts": "n_routed_experts",
        "expert_hidden_dim": "moe_intermediate_size",
        "routed_scale": "routed_scaling_factor",
        "num_shared_experts": "n_shared_experts",
        "num_expert_groups": "n_group",
        "num_expert_groups_per_tok": "topk_group",
    }
    moe_config = config.get("moe", {})
    for old_name, new_name in moe_config_map.items():
        if old_name in moe_config:
            value = moe_config.pop(old_name)
            config[new_name] = value

    config["topk_method"] = None
    config["norm_topk_prob"] = True
    config["scoring_func"] = "softmax"

    return config
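

if __name__ == "__main__":
    # Minimal smoke test with a hypothetical dense config; the values below
    # are illustrative and do not correspond to a real Mistral checkpoint.
    example_config = {
        "dim": 4096,
        "n_layers": 32,
        "n_heads": 32,
        "n_kv_heads": 8,
        "head_dim": 128,
        "hidden_dim": 14336,
        "norm_eps": 1e-5,
        "vocab_size": 32768,
        "rope_theta": 1_000_000.0,
    }
    cfg = adapt_config_dict(example_config, defaults={"attention_dropout": 0.0})
    print(cfg.architectures, cfg.hidden_size, cfg.num_hidden_layers)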